From 120e7af6466190b754cf3026c685a5d31561da90 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 27 Feb 2023 22:02:37 +0100 Subject: [PATCH] fix assert_run_python_script on Windows (#7346) --- test/common_utils.py | 17 ++++++++++------- test/test_transforms.py | 5 ----- test/test_transforms_v2.py | 9 --------- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 2f74f3686c3..697b6f6e4ca 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -844,17 +844,20 @@ def get_closeness_kwargs(self, test_id, *, dtype, device): def assert_run_python_script(source_code): """Utility to check assertions in an independent Python subprocess. + The script provided in the source code should return 0 and not print - anything on stderr or stdout. Taken from scikit-learn test utils. - source_code (str): The Python source code to execute. + anything on stderr or stdout. Modified from scikit-learn test utils. + + Args: + source_code (str): The Python source code to execute. """ - with tempfile.NamedTemporaryFile(mode="wb") as f: - f.write(source_code.encode()) - f.flush() + with get_tmp_dir() as root: + path = pathlib.Path(root) / "main.py" + with open(path, "w") as file: + file.write(source_code) - cmd = [sys.executable, f.name] try: - out = check_output(cmd, stderr=STDOUT) + out = check_output([sys.executable, str(path)], stderr=STDOUT) except CalledProcessError as e: raise RuntimeError(f"script errored with output:\n{e.output.decode()}") if out != b"": diff --git a/test/test_transforms.py b/test/test_transforms.py index b6eccba421b..03b385e9edd 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2,7 +2,6 @@ import os import random import re -import sys import textwrap import warnings from functools import partial @@ -2279,10 +2278,6 @@ def test_random_grayscale_with_grayscale_input(): ), ) @pytest.mark.parametrize("from_private", (True, False)) -@pytest.mark.skipif( - sys.platform in ("win32", "cygwin"), - reason="assert_run_python_script is broken on Windows. Possible fix in https://github.com/pytorch/vision/pull/7346", -) def test_functional_deprecation_warning(import_statement, from_private): if from_private: import_statement = import_statement.replace("functional", "_functional") diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 9fe7bbf51f2..f5ca976963a 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2,7 +2,6 @@ import pathlib import random import re -import sys import textwrap import warnings from collections import defaultdict @@ -2102,10 +2101,6 @@ def test_sanitize_bounding_boxes_errors(): ), ) @pytest.mark.parametrize("call_disable_warning", (True, False)) -@pytest.mark.skipif( - sys.platform in ("win32", "cygwin"), - reason="assert_run_python_script is broken on Windows. Possible fix in https://github.com/pytorch/vision/pull/7346", -) def test_warnings_v2_namespaces(import_statement, call_disable_warning): if call_disable_warning: source = f""" @@ -2125,10 +2120,6 @@ def test_warnings_v2_namespaces(import_statement, call_disable_warning): assert_run_python_script(textwrap.dedent(source)) -@pytest.mark.skipif( - sys.platform in ("win32", "cygwin"), - reason="assert_run_python_script is broken on Windows. Possible fix in https://github.com/pytorch/vision/pull/7346", -) def test_no_warnings_v1_namespace(): source = """ import warnings