diff --git a/tests/tpch/test_dask.py b/tests/tpch/test_dask.py index 9631b09b16..25da951057 100644 --- a/tests/tpch/test_dask.py +++ b/tests/tpch/test_dask.py @@ -304,83 +304,54 @@ def test_query_8(client, dataset_path, fs): var1 = datetime.strptime("1995-01-01", "%Y-%m-%d") var2 = datetime.strptime("1997-01-01", "%Y-%m-%d") - supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs) - lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs) - orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs) - customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs) - nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs) - region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs) - part_ds = dd.read_parquet(dataset_path + "part", filesystem=fs) + supplier = dd.read_parquet(dataset_path + "supplier", filesystem=fs) + lineitem = dd.read_parquet(dataset_path + "lineitem", filesystem=fs) + orders = dd.read_parquet(dataset_path + "orders", filesystem=fs) + customer = dd.read_parquet(dataset_path + "customer", filesystem=fs) + nation = dd.read_parquet(dataset_path + "nation", filesystem=fs) + region = dd.read_parquet(dataset_path + "region", filesystem=fs) + part = dd.read_parquet(dataset_path + "part", filesystem=fs) - part_filtered = part_ds[part_ds["p_type"] == "ECONOMY ANODIZED STEEL"][ - ["p_partkey"] - ] + part = part[part["p_type"] == "ECONOMY ANODIZED STEEL"][["p_partkey"]] - lineitem_filtered = lineitem_ds[["l_partkey", "l_suppkey", "l_orderkey"]] - lineitem_filtered["volume"] = lineitem_ds["l_extendedprice"] * ( - 1.0 - lineitem_ds["l_discount"] - ) - total = part_filtered.merge( - lineitem_filtered, - left_on="p_partkey", - right_on="l_partkey", - how="inner", - )[["l_suppkey", "l_orderkey", "volume"]] + lineitem["volume"] = lineitem["l_extendedprice"] * (1.0 - lineitem["l_discount"]) + total = part.merge(lineitem, left_on="p_partkey", right_on="l_partkey", how="inner") - supplier_filtered = supplier_ds[["s_suppkey", "s_nationkey"]] total = total.merge( - supplier_filtered, - left_on="l_suppkey", - right_on="s_suppkey", - how="inner", - )[["l_orderkey", "volume", "s_nationkey"]] + supplier, left_on="l_suppkey", right_on="s_suppkey", how="inner" + ) - orders_filtered = orders_ds[ - (orders_ds["o_orderdate"] >= var1) & (orders_ds["o_orderdate"] < var2) - ] + orders = orders[(orders["o_orderdate"] >= var1) & (orders["o_orderdate"] < var2)] - orders_filtered["o_year"] = orders_filtered["o_orderdate"].dt.year - orders_filtered = orders_filtered[["o_orderkey", "o_custkey", "o_year"]] + orders["o_year"] = orders["o_orderdate"].dt.year total = total.merge( - orders_filtered, - left_on="l_orderkey", - right_on="o_orderkey", - how="inner", - )[["volume", "s_nationkey", "o_custkey", "o_year"]] + orders, left_on="l_orderkey", right_on="o_orderkey", how="inner" + ) - customer_filtered = customer_ds[["c_custkey", "c_nationkey"]] total = total.merge( - customer_filtered, - left_on="o_custkey", - right_on="c_custkey", - how="inner", - )[["volume", "s_nationkey", "o_year", "c_nationkey"]] - - n1_filtered = nation_ds[["n_nationkey", "n_regionkey"]] - n2_filtered = nation_ds[["n_nationkey", "n_name"]].rename( - columns={"n_name": "nation"} + customer, left_on="o_custkey", right_on="c_custkey", how="inner" ) + + n1_filtered = nation[["n_nationkey", "n_regionkey"]] + n2_filtered = nation[["n_nationkey", "n_name"]].rename(columns={"n_name": "nation"}) total = total.merge( n1_filtered, left_on="c_nationkey", right_on="n_nationkey", how="inner", - )[["volume", "s_nationkey", "o_year", "n_regionkey"]] + ) total = total.merge( n2_filtered, left_on="s_nationkey", right_on="n_nationkey", how="inner", - )[["volume", "o_year", "n_regionkey", "nation"]] + ) - region_filtered = region_ds[region_ds["r_name"] == "AMERICA"][["r_regionkey"]] + region_filtered = region[region["r_name"] == "AMERICA"][["r_regionkey"]] total = total.merge( - region_filtered, - left_on="n_regionkey", - right_on="r_regionkey", - how="inner", - )[["volume", "o_year", "nation"]] + region_filtered, left_on="n_regionkey", right_on="r_regionkey", how="inner" + ) mkt_brazil = ( total[total["nation"] == "BRAZIL"].groupby("o_year").volume.sum().reset_index() @@ -390,6 +361,5 @@ def test_query_8(client, dataset_path, fs): mkt_brazil, left_on="o_year", right_on="o_year", suffixes=("_mkt", "_brazil") ) final["mkt_share"] = final.volume_brazil / final.volume_mkt - total = total.sort_values(by=["o_year"], ascending=[True]) - - total.compute() + final = final.sort_values(by=["o_year"], ascending=[True]) + final.compute()