Skip to content

Commit cd6bdcd

Browse files
Merge pull request #15 from awslabs/better_s3_test
Upload s3 artifacts during python profiling tests
2 parents 2eb38ae + 9b758be commit cd6bdcd

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

tests/profiler/core/test_python_profiler.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import time
77

88
# Third Party
9+
import boto3
910
import pytest
1011

1112
# First Party
13+
from smdebug.core.access_layer.utils import is_s3
1214
from smdebug.profiler.analysis.python_profile_analysis import PyinstrumentAnalysis, cProfileAnalysis
1315
from smdebug.profiler.profiler_constants import (
1416
CONVERT_TO_MICROSECS,
@@ -51,9 +53,9 @@ def reset_python_profiler_dir(cprofile_dir, pyinstrument_dir):
5153
shutil.rmtree(pyinstrument_dir, ignore_errors=True)
5254

5355

54-
@pytest.fixture()
56+
@pytest.fixture(scope="session")
5557
def bucket_prefix():
56-
return "s3://smdebug-testing/resources/python_profile"
58+
return f"s3://smdebug-testing/resources/python_profile/{int(time.time())}"
5759

5860

5961
@pytest.fixture(scope="session")
@@ -91,6 +93,15 @@ def time_function():
9193
python_profiler.stop_profiling()
9294

9395

96+
def _upload_s3_folder(bucket, key, folder):
97+
s3_client = boto3.client("s3")
98+
for root, _, files in os.walk(folder):
99+
for file in files:
100+
dir = os.path.basename(root)
101+
full_key = os.path.join(key, dir, file)
102+
s3_client.upload_file(os.path.join(root, file), bucket, full_key)
103+
104+
94105
@pytest.mark.parametrize("steps", [(1, 2), (1, 5)])
95106
def test_cprofile_profiling(cprofile_python_profiler, steps, cprofile_dir):
96107
"""
@@ -132,8 +143,10 @@ def test_pyinstrument_profiling(pyinstrument_python_profiler, steps, pyinstrumen
132143
assert json.load(f) # validate output file
133144

134145

135-
@pytest.mark.parametrize("s3", [True, False])
136-
def test_cprofile_analysis(cprofile_python_profiler, cprofile_dir, bucket_prefix, s3):
146+
@pytest.mark.parametrize("s3", [False, True])
147+
def test_cprofile_analysis(
148+
cprofile_python_profiler, cprofile_dir, bucket_prefix, test_framework, s3
149+
):
137150
"""
138151
This test is meant to test that the cProfile analysis retrieves the correct step's stats based on the specified
139152
interval. Stats are either retrieved from s3 or generated manually through python profiling.
@@ -148,6 +161,9 @@ def test_cprofile_analysis(cprofile_python_profiler, cprofile_dir, bucket_prefix
148161
# Do analysis and use those stats.
149162
_analysis_set_up(cprofile_python_profiler)
150163
python_profile_analysis = cProfileAnalysis(local_profile_dir=cprofile_dir)
164+
_, bucket, prefix = is_s3(bucket_prefix)
165+
key = os.path.join(prefix, "framework", test_framework, CPROFILE_NAME)
166+
_upload_s3_folder(bucket, key, cprofile_dir)
151167

152168
# Test that step_function call is recorded in received stats, but not time_function.
153169
assert len(python_profile_analysis.python_profile_stats) == 2
@@ -170,8 +186,10 @@ def test_cprofile_analysis(cprofile_python_profiler, cprofile_dir, bucket_prefix
170186
assert all(["step_function" not in stat.function_name for stat in function_stats_list])
171187

172188

173-
@pytest.mark.parametrize("s3", [True, False])
174-
def test_pyinstrument_analysis(pyinstrument_python_profiler, pyinstrument_dir, bucket_prefix, s3):
189+
@pytest.mark.parametrize("s3", [False, True])
190+
def test_pyinstrument_analysis(
191+
pyinstrument_python_profiler, pyinstrument_dir, test_framework, bucket_prefix, s3
192+
):
175193
"""
176194
This test is meant to test that the pyinstrument analysis retrieves the correct step's stats based on the specified
177195
interval. Stats are either retrieved from s3 or generated manually through python profiling.
@@ -186,6 +204,9 @@ def test_pyinstrument_analysis(pyinstrument_python_profiler, pyinstrument_dir, b
186204
# Do analysis and use those stats.
187205
_analysis_set_up(pyinstrument_python_profiler)
188206
python_profile_analysis = PyinstrumentAnalysis(local_profile_dir=pyinstrument_dir)
207+
_, bucket, prefix = is_s3(bucket_prefix)
208+
key = os.path.join(prefix, "framework", test_framework, PYINSTRUMENT_NAME)
209+
_upload_s3_folder(bucket, key, pyinstrument_dir)
189210

190211
# Test that step_function call is recorded in received stats, but not time_function.
191212
assert len(python_profile_analysis.python_profile_stats) == 2

0 commit comments

Comments
 (0)