Commit 79c88d8

refactor: update benchmark logic to use dataset path and warehouse size (#1771)

* refactor: update benchmark logic to use dataset path and warehouse size, update error catching for loading data to Snowflake, update env example
* refactor: remove odd logic in run_on_emb and fix path generation for Embucket

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent f2e5f81 commit 79c88d8

File tree: 4 files changed (+55, -57 lines)


benchmark/benchmark.py

Lines changed: 24 additions & 28 deletions
@@ -24,13 +24,14 @@
 logger = logging.getLogger(__name__)


-def get_results_path(system: SystemType, benchmark_type: str, scale_factor: str,
-                     warehouse_or_instance: str, run_number: Optional[int] = None) -> str:
+def get_results_path(system: SystemType, benchmark_type: str, dataset_path: str,
+                     instance: str, warehouse_size: str = None, run_number: Optional[int] = None) -> str:
     """Generate path for storing benchmark results."""
     if system == SystemType.SNOWFLAKE:
-        base_path = f"result/snowflake_{benchmark_type}_results/{scale_factor}/{warehouse_or_instance}"
+        # Use warehouse size in the path instead of warehouse name
+        base_path = f"result/snowflake_{benchmark_type}_results/{dataset_path}/{warehouse_size}"
     elif system == SystemType.EMBUCKET:
-        base_path = f"result/embucket_{benchmark_type}_results/{scale_factor}/{warehouse_or_instance}"
+        base_path = f"result/embucket_{benchmark_type}_results/{dataset_path}/{instance}"
     else:
         raise ValueError(f"Unsupported system: {system}")

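For orientation, a minimal sketch of what the two visible branches now produce, using the defaults from env_example in this same commit (DATASET_PATH=tpch/01, SNOWFLAKE_WAREHOUSE_SIZE=XSMALL, EMBUCKET_INSTANCE=c7i_2xlarge); how run_number is appended falls below the hunk and is not reproduced here:

    # Illustrative values only, taken from env_example in this commit.
    benchmark_type = "tpch"
    dataset_path = "tpch/01"      # DATASET_PATH
    warehouse_size = "XSMALL"     # SNOWFLAKE_WAREHOUSE_SIZE
    instance = "c7i_2xlarge"      # EMBUCKET_INSTANCE

    # Snowflake results are now keyed by dataset path and warehouse size:
    print(f"result/snowflake_{benchmark_type}_results/{dataset_path}/{warehouse_size}")
    # -> result/snowflake_tpch_results/tpch/01/XSMALL

    # Embucket results are keyed by dataset path and instance type:
    print(f"result/embucket_{benchmark_type}_results/{dataset_path}/{instance}")
    # -> result/embucket_tpch_results/tpch/01/c7i_2xlarge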

@@ -149,7 +150,7 @@ def run_on_sf(cursor, warehouse, tpch_queries):
     return results


-def run_on_emb(cursor, tpch_queries):
+def run_on_emb(tpch_queries):
     """Run TPCH queries on Embucket with container restart before each query."""
     docker_manager = create_docker_manager()
     executed_query_ids = []
@@ -271,11 +272,11 @@ def run_snowflake_benchmark(run_number: int):
     # Get benchmark configuration from environment variables
     benchmark_type = os.environ.get("BENCHMARK_TYPE", "tpch")
     warehouse = os.environ["SNOWFLAKE_WAREHOUSE"]
-    dataset = os.environ["DATASET_NAME"]
-    scale_factor = os.environ["DATASET_SCALE_FACTOR"]
+    warehouse_size = os.environ["SNOWFLAKE_WAREHOUSE_SIZE"]
+    dataset_path = os.environ["DATASET_PATH"]

     logger.info(f"Starting Snowflake {benchmark_type} benchmark run {run_number}")
-    logger.info(f"Dataset: {dataset}, Schema: {scale_factor}, Warehouse: {warehouse}")
+    logger.info(f"Dataset: {dataset_path}, Warehouse: {warehouse}, Size: {warehouse_size}")

     # Get queries and run benchmark
     queries = get_queries_for_benchmark(benchmark_type, for_embucket=False)
@@ -286,9 +287,9 @@ def run_snowflake_benchmark(run_number: int):
     # Disable query result caching for benchmark
     sf_cursor.execute("ALTER SESSION SET USE_CACHED_RESULT = FALSE;")

-    sf_results = run_on_sf(sf_cursor,warehouse, queries)
+    sf_results = run_on_sf(sf_cursor, warehouse, queries)

-    results_path = get_results_path(SystemType.SNOWFLAKE, benchmark_type, scale_factor, warehouse, run_number)
+    results_path = get_results_path(SystemType.SNOWFLAKE, benchmark_type, dataset_path, warehouse, warehouse_size, run_number)
     os.makedirs(os.path.dirname(results_path), exist_ok=True)
     save_results_to_csv(sf_results, filename=results_path, system=SystemType.SNOWFLAKE)

@@ -298,50 +299,49 @@ def run_snowflake_benchmark(run_number: int):
     sf_connection.close()

     # Check if we have 3 CSV files ready and calculate averages if so
-    results_dir = get_results_path(SystemType.SNOWFLAKE, benchmark_type, scale_factor, warehouse)
+    results_dir = get_results_path(SystemType.SNOWFLAKE, benchmark_type, dataset_path, warehouse, warehouse_size)
     csv_files = glob.glob(os.path.join(results_dir, "snowflake_results_run_*.csv"))
     if len(csv_files) == 3:
         logger.info("Found 3 CSV files. Calculating averages...")
         calculate_benchmark_averages(
-            scale_factor,
-            warehouse,
+            dataset_path,
+            warehouse_size,  # Pass warehouse size instead of name
             SystemType.SNOWFLAKE,
             benchmark_type
         )

     return sf_results


+
 def run_embucket_benchmark(run_number: int):
     """Run benchmark on Embucket with container restarts."""
     # Get benchmark configuration from environment variables
     benchmark_type = os.environ.get("BENCHMARK_TYPE", "tpch")
     instance = os.environ["EMBUCKET_INSTANCE"]
-    dataset = os.environ.get("EMBUCKET_DATASET", os.environ["DATASET_NAME"])
-    scale_factor = os.environ["DATASET_SCALE_FACTOR"]
+    dataset_path = os.environ.get("EMBUCKET_DATASET_PATH", os.environ["DATASET_PATH"])

     logger.info(f"Starting Embucket {benchmark_type} benchmark run {run_number}")
-    logger.info(f"Instance: {instance}, Dataset: {dataset}, Scale Factor: {scale_factor}")
+    logger.info(f"Instance: {instance}, Dataset: {dataset_path}")

     # Get queries and docker manager
     queries = get_queries_for_benchmark(benchmark_type, for_embucket=True)
-    docker_manager = create_docker_manager()

     # Run benchmark
-    emb_results = run_on_emb(docker_manager, queries)
+    emb_results = run_on_emb(queries)

-    results_path = get_results_path(SystemType.EMBUCKET, benchmark_type, scale_factor, instance, run_number)
+    results_path = get_results_path(SystemType.EMBUCKET, benchmark_type, dataset_path, instance, run_number=run_number)
     os.makedirs(os.path.dirname(results_path), exist_ok=True)
     save_results_to_csv(emb_results, filename=results_path, system=SystemType.EMBUCKET)
     logger.info(f"Embucket benchmark results saved to: {results_path}")

     # Check if we have 3 CSV files ready and calculate averages
-    results_dir = get_results_path(SystemType.EMBUCKET, benchmark_type, scale_factor, instance)
+    results_dir = get_results_path(SystemType.EMBUCKET, benchmark_type, dataset_path, instance)
     csv_files = glob.glob(os.path.join(results_dir, "embucket_results_run_*.csv"))
     if len(csv_files) == 3:
         logger.info("Found 3 CSV files. Calculating averages...")
         calculate_benchmark_averages(
-            scale_factor,
+            dataset_path,
             instance,
             SystemType.EMBUCKET,
             benchmark_type
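One call-site subtlety introduced by the optional warehouse_size parameter: the Snowflake path fills it positionally before run_number, while the Embucket path must pass run_number by keyword to skip it. A self-contained sketch of the hazard, using a stand-in with the same parameter order as the new signature:

    def get_results_path(system, benchmark_type, dataset_path,
                         instance, warehouse_size=None, run_number=None):
        """Stand-in with the same parameter order as the signature above."""
        return (system, benchmark_type, dataset_path, instance, warehouse_size, run_number)

    # Snowflake: warehouse_size occupies the fifth positional slot, then run_number.
    print(get_results_path("snowflake", "tpch", "tpch/01", "BENCHMARK_WH", "XSMALL", 1))

    # Embucket: no warehouse size, so run_number must go by keyword;
    # passed positionally it would land in warehouse_size instead.
    print(get_results_path("embucket", "tpch", "tpch/01", "c7i_2xlarge", run_number=1))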
@@ -398,8 +398,7 @@ def parse_args():
     parser.add_argument("--platform", choices=["snowflake", "embucket", "both"], default="both")
     parser.add_argument("--runs", type=int, default=3)
     parser.add_argument("--benchmark-type", choices=["tpch", "tpcds"], default=os.environ.get("BENCHMARK_TYPE", "tpch"))
-    parser.add_argument("--dataset-name", help="Override the DATASET_NAME environment variable")
-    parser.add_argument("--scale-factor", help="Override the DATASET_SCALE_FACTOR environment variable")
+    parser.add_argument("--dataset-path", help="Override the DATASET_PATH environment variable")
     return parser.parse_args()

@@ -410,11 +409,8 @@ def parse_args():
     if args.benchmark_type != os.environ.get("BENCHMARK_TYPE", "tpch"):
         os.environ["BENCHMARK_TYPE"] = args.benchmark_type

-    if args.dataset_name:
-        os.environ["DATASET_NAME"] = args.dataset_name
-
-    if args.scale_factor:
-        os.environ["DATASET_SCALE_FACTOR"] = args.scale_factor
+    if args.dataset_path:
+        os.environ["DATASET_PATH"] = args.dataset_path

     # Execute benchmarks based on platform selection
     if args.platform == "snowflake":
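The CLI now funnels everything through the single DATASET_PATH variable. A small sketch of the override flow (argument and variable names are the ones from this diff; the hard-coded argv stands in for something like `python benchmark.py --dataset-path tpch/01`):

    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-path", help="Override the DATASET_PATH environment variable")
    args = parser.parse_args(["--dataset-path", "tpch/01"])

    if args.dataset_path:
        os.environ["DATASET_PATH"] = args.dataset_path  # downstream code reads only this
    print(os.environ["DATASET_PATH"])  # -> tpch/01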

benchmark/data_preparation.py

Lines changed: 21 additions & 23 deletions
@@ -23,13 +23,12 @@ def create_tables(cursor, system):
         cursor.execute(ddl_sql.strip())


-def upload_parquet_to_snowflake_tables(cursor, dataset, dataset_scale_factor):
+def upload_parquet_to_snowflake_tables(cursor, dataset_path):
     """Upload parquet files to Snowflake tables from S3 stage."""
     table_names = get_table_names(fully_qualified_names_for_embucket=False)
     for table_name in table_names.values():
         print(f"Loading data into Snowflake table {table_name}...")
-        # Load data directly from the S3 stage
-        s3_path = f"s3://embucket-testdata/{dataset}/{dataset_scale_factor}/{table_name}.parquet"
+        s3_path = f"s3://embucket-testdata/{dataset_path}/{table_name}.parquet"
         cursor.execute(f"""
             COPY INTO {table_name}
             FROM '{s3_path}'
@@ -38,9 +37,13 @@ def upload_parquet_to_snowflake_tables(cursor, dataset_path):
             FILE_FORMAT = (TYPE = PARQUET)
             MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE;
         """)
+        result = cursor.fetchall()
+        if result and result[0][0] == 'Copy executed with 0 files processed.':
+            raise RuntimeError(f"No files processed for {table_name}. Check S3 path: {s3_path}")


-def upload_parquet_to_embucket_tables(cursor, dataset, dataset_scale_factor):
+
+def upload_parquet_to_embucket_tables(cursor, dataset_path):
     """Upload parquet files to Embucket tables using COPY INTO."""
     # Get fully qualified table names using the unified logic
     table_names = get_table_names(fully_qualified_names_for_embucket=True)
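The new guard leans on Snowflake returning a single status row whose first column reads 'Copy executed with 0 files processed.' when a COPY INTO matches no files; fetching it right after execute() turns a silently empty load into a hard failure. The same check as a reusable helper (the helper itself is a sketch, not part of this commit):

    def assert_copy_loaded_files(cursor, table_name: str, s3_path: str) -> None:
        """Raise if the COPY INTO just executed on `cursor` processed zero files."""
        result = cursor.fetchall()
        if result and result[0][0] == 'Copy executed with 0 files processed.':
            raise RuntimeError(f"No files processed for {table_name}. Check S3 path: {s3_path}")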
@@ -50,31 +53,31 @@ def upload_parquet_to_embucket_tables(cursor, dataset_path):
         bare_table_name = qualified_table_name.split('.')[-1]
         print(f"Loading data into Embucket table {qualified_table_name}...")

-        copy_sql = f"COPY INTO {qualified_table_name} FROM 's3://embucket-testdata/{dataset}/{dataset_scale_factor}/{bare_table_name}.parquet' FILE_FORMAT = (TYPE = PARQUET)"
+        copy_sql = f"COPY INTO {qualified_table_name} FROM 's3://embucket-testdata/{dataset_path}/{bare_table_name}.parquet' FILE_FORMAT = (TYPE = PARQUET)"
         cursor.execute(copy_sql)


-def prepare_data_for_embucket(dataset, dataset_scale_factor):
+def prepare_data_for_embucket(dataset_path):
     """Prepare data for Embucket: generate data, create tables, and load data."""
     # Connect to Embucket
     cursor = create_embucket_connection().cursor()
     # Create tables
     create_tables(cursor, SystemType.EMBUCKET)
     # Load data into Embucket tables
-    upload_parquet_to_embucket_tables(cursor, dataset, dataset_scale_factor)
+    upload_parquet_to_embucket_tables(cursor, dataset_path)

     cursor.close()
     print("Embucket data preparation completed successfully.")


-def prepare_data_for_snowflake(dataset, dataset_scale_factor):
+def prepare_data_for_snowflake(dataset_path):
     """Prepare data, create tables, and load data for Snowflake"""
     # Connect to Snowflake
     cursor = create_snowflake_connection().cursor()
     # Create tables
     create_tables(cursor, SystemType.SNOWFLAKE)
     # Load data into Snowflake tables
-    upload_parquet_to_snowflake_tables(cursor, dataset, dataset_scale_factor)
+    upload_parquet_to_snowflake_tables(cursor, dataset_path)

     cursor.close()
     print("Snowflake data preparation completed successfully.")
@@ -84,24 +87,19 @@ def prepare_data_for_snowflake(dataset_path):
     parser = argparse.ArgumentParser(description="Prepare data for Embucket/Snowflake benchmarks")
     parser.add_argument("--system", type=str, choices=["embucket", "snowflake", "both"],
                         default="both", help="Which system to prepare data for")
-    parser.add_argument("--dataset", type=str, default=os.environ.get("DATASET_NAME", "tpch"),
-                        help="Dataset name (default: from env or 'tpch')")
-    parser.add_argument("--scale", type=str, default=os.environ.get("DATASET_SCALE_FACTOR", "01"),
-                        help="Dataset scale factor (default: from env or '1')")
+    parser.add_argument("--dataset-path", type=str, default=os.environ.get("DATASET_PATH", "tpch/01"),
+                        help="Dataset path in format 'dataset/scale' (default: from env or 'tpch/01')")

     args = parser.parse_args()

-    # Override environment variables if specified in args
-    if args.dataset is not None:
-        os.environ["DATASET_NAME"] = args.dataset
-
-    if args.scale is not None:
-        os.environ["DATASET_SCALE_FACTOR"] = args.scale
+    # Override environment variable if specified in args
+    if args.dataset_path:
+        os.environ["DATASET_PATH"] = args.dataset_path

-    print(f"Preparing data for dataset: {args.dataset}, scale: {args.scale}")
+    print(f"Preparing data for dataset path: {args.dataset_path}")

-    # if args.system.lower() in ["embucket", "both"]:
-    #     prepare_data_for_embucket(args.dataset, args.scale)
+    if args.system.lower() in ["embucket", "both"]:
+        prepare_data_for_embucket(args.dataset_path)

     if args.system.lower() in ["snowflake", "both"]:
-        prepare_data_for_snowflake(args.dataset, args.scale)
+        prepare_data_for_snowflake(args.dataset_path)
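With dataset name and scale factor collapsed into one DATASET_PATH segment, both loaders build the same flat S3 layout. A quick sketch using the env_example default (the table name is illustrative):

    dataset_path = "tpch/01"   # DATASET_PATH, per env_example
    table_name = "lineitem"    # illustrative TPC-H table
    print(f"s3://embucket-testdata/{dataset_path}/{table_name}.parquet")
    # -> s3://embucket-testdata/tpch/01/lineitem.parquet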

benchmark/env_example

Lines changed: 4 additions & 3 deletions
@@ -2,14 +2,15 @@ SNOWFLAKE_ACCOUNT=your_snowflake_account
 SNOWFLAKE_USER=your_snowflake_user
 SNOWFLAKE_PASSWORD=your_snowflake_password
 SNOWFLAKE_DATABASE=benchmark_db
-SNOWFLAKE_WAREHOUSE=BENCHMARK_WH_XS
+SNOWFLAKE_WAREHOUSE=BENCHMARK_WH

 EMBUCKET_INSTANCE=c7i_2xlarge
+SNOWFLAKE_WAREHOUSE_SIZE=XSMALL

 BENCHMARK_TYPE=tpch
 DATASET_S3_BUCKET=embucket-testdata
-DATASET_NAME=tpch_data
-DATASET_SCALE_FACTOR=sf_01
+#dataset and scale factor path inside the s3 bucket
+DATASET_PATH=tpch/01

 EMBUCKET_ACCOUNT=embucket
 EMBUCKET_USER=embucket
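Since this commit renames several variables, a pre-flight check like the following can catch a stale .env early (a sketch; the list of names comes from this diff, the check itself does not exist in the repo):

    import os

    REQUIRED = ["SNOWFLAKE_WAREHOUSE", "SNOWFLAKE_WAREHOUSE_SIZE",
                "DATASET_PATH", "EMBUCKET_INSTANCE"]
    missing = [name for name in REQUIRED if not os.environ.get(name)]
    if missing:
        raise SystemExit(f"Missing environment variables: {', '.join(missing)}")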

benchmark/utils.py

Lines changed: 6 additions & 3 deletions
@@ -35,14 +35,13 @@ def create_embucket_connection():
     return conn


-
 def create_snowflake_connection():
     """Create Snowflake connection with environment-based config."""
     user = os.environ["SNOWFLAKE_USER"]
     password = os.environ["SNOWFLAKE_PASSWORD"]
     account = os.environ["SNOWFLAKE_ACCOUNT"]
-    database = os.environ["DATASET_NAME"]
-    schema = os.environ["DATASET_SCALE_FACTOR"]
+    database = os.environ["SNOWFLAKE_DATABASE"]
+    schema = os.environ["SNOWFLAKE_SCHEMA"]
     warehouse = os.environ["SNOWFLAKE_WAREHOUSE"]

     if not all([user, password, account, database, schema, warehouse]):
@@ -59,6 +58,9 @@ def create_snowflake_connection():

     conn = sf.connect(**connect_args)

+    conn.cursor().execute(f"CREATE OR REPLACE WAREHOUSE {warehouse} WITH WAREHOUSE_SIZE = '{os.environ['SNOWFLAKE_WAREHOUSE_SIZE']}';")
+    conn.cursor().execute(f"USE WAREHOUSE {warehouse};")
+
     conn.cursor().execute(f"CREATE DATABASE IF NOT EXISTS {database}")
     conn.cursor().execute(f"CREATE SCHEMA IF NOT EXISTS {schema}")
     conn.cursor().execute(f"USE SCHEMA {schema}")
@@ -67,3 +69,4 @@ def create_snowflake_connection():
     conn.cursor().execute("CREATE OR REPLACE TEMPORARY STAGE sf_prep_stage FILE_FORMAT = sf_parquet_format;")

     return conn
+
