diff --git a/tests/tpch/test_correctness.py b/tests/tpch/test_correctness.py index 53700e759f..58188691d7 100644 --- a/tests/tpch/test_correctness.py +++ b/tests/tpch/test_correctness.py @@ -77,7 +77,7 @@ def verify_result(result: pd.DataFrame, query: int, answer_dir: pathlib.Path): 5, 6, 7, - pytest.param(8, marks=pytest.mark.skip(reason="Not implemented")), + 8, 9, pytest.param(10, marks=pytest.mark.xfail(reason="Result is wrong")), 11, diff --git a/tests/tpch/test_dask.py b/tests/tpch/test_dask.py index 3a9596ba26..9dfaa58db5 100644 --- a/tests/tpch/test_dask.py +++ b/tests/tpch/test_dask.py @@ -327,6 +327,66 @@ def test_query_7(client, dataset_path, fs): return result +def test_query_8(client, dataset_path, fs): + var1 = datetime.strptime("1995-01-01", "%Y-%m-%d") + var2 = datetime.strptime("1997-01-01", "%Y-%m-%d") + + supplier = dd.read_parquet(dataset_path + "supplier", filesystem=fs) + lineitem = dd.read_parquet(dataset_path + "lineitem", filesystem=fs) + orders = dd.read_parquet(dataset_path + "orders", filesystem=fs) + customer = dd.read_parquet(dataset_path + "customer", filesystem=fs) + nation = dd.read_parquet(dataset_path + "nation", filesystem=fs) + region = dd.read_parquet(dataset_path + "region", filesystem=fs) + part = dd.read_parquet(dataset_path + "part", filesystem=fs) + + part = part[part["p_type"] == "ECONOMY ANODIZED STEEL"][["p_partkey"]] + lineitem["volume"] = lineitem["l_extendedprice"] * (1.0 - lineitem["l_discount"]) + total = part.merge(lineitem, left_on="p_partkey", right_on="l_partkey", how="inner") + + total = total.merge( + supplier, left_on="l_suppkey", right_on="s_suppkey", how="inner" + ) + + orders = orders[(orders["o_orderdate"] >= var1) & (orders["o_orderdate"] < var2)] + orders["o_year"] = orders["o_orderdate"].dt.year + total = total.merge( + orders, left_on="l_orderkey", right_on="o_orderkey", how="inner" + ) + + total = total.merge( + customer, left_on="o_custkey", right_on="c_custkey", how="inner" + ) + + n1_filtered = nation[["n_nationkey", "n_regionkey"]] + total = total.merge( + n1_filtered, left_on="c_nationkey", right_on="n_nationkey", how="inner" + ) + + n2_filtered = nation[["n_nationkey", "n_name"]].rename(columns={"n_name": "nation"}) + total = total.merge( + n2_filtered, left_on="s_nationkey", right_on="n_nationkey", how="inner" + ) + + region = region[region["r_name"] == "AMERICA"][["r_regionkey"]] + total = total.merge( + region, left_on="n_regionkey", right_on="r_regionkey", how="inner" + ) + + mkt_brazil = ( + total[total["nation"] == "BRAZIL"].groupby("o_year").volume.sum().reset_index() + ) + mkt_total = total.groupby("o_year").volume.sum().reset_index() + + final = mkt_total.merge( + mkt_brazil, left_on="o_year", right_on="o_year", suffixes=("_mkt", "_brazil") + ) + final["mkt_share"] = final.volume_brazil / final.volume_mkt + final = final.sort_values(by=["o_year"], ascending=[True])[["o_year", "mkt_share"]] + result = final.compute() + + return result + + @pytest.mark.shuffle_p2p def test_query_9(client, dataset_path, fs): """