Skip to content

Commit

Permalink
Fixup: refactor / cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
milesgranger committed Oct 31, 2023
1 parent 1717751 commit 74431e7
Showing 1 changed file with 27 additions and 57 deletions.
84 changes: 27 additions & 57 deletions tests/tpch/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,83 +304,54 @@ def test_query_8(client, dataset_path, fs):
var1 = datetime.strptime("1995-01-01", "%Y-%m-%d")
var2 = datetime.strptime("1997-01-01", "%Y-%m-%d")

supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
part_ds = dd.read_parquet(dataset_path + "part", filesystem=fs)
supplier = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
lineitem = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
orders = dd.read_parquet(dataset_path + "orders", filesystem=fs)
customer = dd.read_parquet(dataset_path + "customer", filesystem=fs)
nation = dd.read_parquet(dataset_path + "nation", filesystem=fs)
region = dd.read_parquet(dataset_path + "region", filesystem=fs)
part = dd.read_parquet(dataset_path + "part", filesystem=fs)

part_filtered = part_ds[part_ds["p_type"] == "ECONOMY ANODIZED STEEL"][
["p_partkey"]
]
part = part[part["p_type"] == "ECONOMY ANODIZED STEEL"][["p_partkey"]]

lineitem_filtered = lineitem_ds[["l_partkey", "l_suppkey", "l_orderkey"]]
lineitem_filtered["volume"] = lineitem_ds["l_extendedprice"] * (
1.0 - lineitem_ds["l_discount"]
)
total = part_filtered.merge(
lineitem_filtered,
left_on="p_partkey",
right_on="l_partkey",
how="inner",
)[["l_suppkey", "l_orderkey", "volume"]]
lineitem["volume"] = lineitem["l_extendedprice"] * (1.0 - lineitem["l_discount"])
total = part.merge(lineitem, left_on="p_partkey", right_on="l_partkey", how="inner")

supplier_filtered = supplier_ds[["s_suppkey", "s_nationkey"]]
total = total.merge(
supplier_filtered,
left_on="l_suppkey",
right_on="s_suppkey",
how="inner",
)[["l_orderkey", "volume", "s_nationkey"]]
supplier, left_on="l_suppkey", right_on="s_suppkey", how="inner"
)

orders_filtered = orders_ds[
(orders_ds["o_orderdate"] >= var1) & (orders_ds["o_orderdate"] < var2)
]
orders = orders[(orders["o_orderdate"] >= var1) & (orders["o_orderdate"] < var2)]

orders_filtered["o_year"] = orders_filtered["o_orderdate"].dt.year
orders_filtered = orders_filtered[["o_orderkey", "o_custkey", "o_year"]]
orders["o_year"] = orders["o_orderdate"].dt.year
total = total.merge(
orders_filtered,
left_on="l_orderkey",
right_on="o_orderkey",
how="inner",
)[["volume", "s_nationkey", "o_custkey", "o_year"]]
orders, left_on="l_orderkey", right_on="o_orderkey", how="inner"
)

customer_filtered = customer_ds[["c_custkey", "c_nationkey"]]
total = total.merge(
customer_filtered,
left_on="o_custkey",
right_on="c_custkey",
how="inner",
)[["volume", "s_nationkey", "o_year", "c_nationkey"]]

n1_filtered = nation_ds[["n_nationkey", "n_regionkey"]]
n2_filtered = nation_ds[["n_nationkey", "n_name"]].rename(
columns={"n_name": "nation"}
customer, left_on="o_custkey", right_on="c_custkey", how="inner"
)

n1_filtered = nation[["n_nationkey", "n_regionkey"]]
n2_filtered = nation[["n_nationkey", "n_name"]].rename(columns={"n_name": "nation"})
total = total.merge(
n1_filtered,
left_on="c_nationkey",
right_on="n_nationkey",
how="inner",
)[["volume", "s_nationkey", "o_year", "n_regionkey"]]
)

total = total.merge(
n2_filtered,
left_on="s_nationkey",
right_on="n_nationkey",
how="inner",
)[["volume", "o_year", "n_regionkey", "nation"]]
)

region_filtered = region_ds[region_ds["r_name"] == "AMERICA"][["r_regionkey"]]
region_filtered = region[region["r_name"] == "AMERICA"][["r_regionkey"]]
total = total.merge(
region_filtered,
left_on="n_regionkey",
right_on="r_regionkey",
how="inner",
)[["volume", "o_year", "nation"]]
region_filtered, left_on="n_regionkey", right_on="r_regionkey", how="inner"
)

mkt_brazil = (
total[total["nation"] == "BRAZIL"].groupby("o_year").volume.sum().reset_index()
Expand All @@ -390,6 +361,5 @@ def test_query_8(client, dataset_path, fs):
mkt_brazil, left_on="o_year", right_on="o_year", suffixes=("_mkt", "_brazil")
)
final["mkt_share"] = final.volume_brazil / final.volume_mkt
total = total.sort_values(by=["o_year"], ascending=[True])

total.compute()
final = final.sort_values(by=["o_year"], ascending=[True])
final.compute()

0 comments on commit 74431e7

Please sign in to comment.