Skip to content

Commit

Permalink
update example
Browse files Browse the repository at this point in the history
  • Loading branch information
hjkim-postechdblab committed Nov 7, 2023
1 parent d44dee5 commit 59b8003
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 17 deletions.
5 changes: 4 additions & 1 deletion data/database/nba/dtype_dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"play_by_play.neutraldescription": "str",
"play_by_play.visitordescription": "str",
"play_by_play.score": "str",
"play_by_play.scoremargin": "int",
"play_by_play.scoremargin": "str",
"play_by_play.player2_team_nickname": "str",
"play_by_play.player2_team_abbreviation": "str",
"play_by_play.player1_name": "str",
Expand Down Expand Up @@ -239,6 +239,9 @@
"game.team_abbreviation_home": "str",
"game.team_name_home": "str",
"game.team_name_away": "str",
"game.fg3m_away": "int",
"game.fg3a_away": "int",
"game.pf_home": "int",
"game.game_date": "date",
"game.matchup_home": "str",
"game.win_lose_home": "str",
Expand Down
70 changes: 54 additions & 16 deletions src/sql_generator/query_generator_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def check_sql_result(data_manager, sql):
return False


def create_full_outer_view(args, rng, data_manager, schema, hist=None, column_info=None, SEED=1234):
def create_full_outer_view(
args, rng, data_manager, schema, hist=None, column_info=None, SEED=1234
):
args.logger.info("... start generating full outer joined view...")

if args.use_one_predicate:
Expand All @@ -42,7 +44,8 @@ def create_full_outer_view(args, rng, data_manager, schema, hist=None, column_in
# (Advanced, TODO) Group by, Aggregation, NOT EXISTS: Inner query should contain the predicate
if args.predicate_id:
view_name = sql_genetion_utils.get_view_name(
"main_given_predicate", [schema["dataset"], schema["use_cols"], args.predicate_id]
"main_given_predicate",
[schema["dataset"], schema["use_cols"], args.predicate_id],
)
###### This fo view should already exist
else:
Expand All @@ -53,7 +56,9 @@ def create_full_outer_view(args, rng, data_manager, schema, hist=None, column_in
global_column_idx += len(args.table_info[table])
continue
for column in args.table_info[table]:
hist[table][column] = [val for val in hist[table][column] if val[0] is not None]
hist[table][column] = [
val for val in hist[table][column] if val[0] is not None
]
if column in hist[table].keys() and len(hist[table][column]) > 0:
candidate_columns.append((table, column, global_column_idx))
global_column_idx += 1
Expand All @@ -62,18 +67,23 @@ def create_full_outer_view(args, rng, data_manager, schema, hist=None, column_in
"[WARNING] There is no candidate predicates which can be used for generating a full outer join view; start generating for all rows"
)
args.use_one_predicate = False
view_name = sql_genetion_utils.get_view_name("main", [schema["dataset"], schema["use_cols"]])
view_name = sql_genetion_utils.get_view_name(
"main", [schema["dataset"], schema["use_cols"]]
)
else:
chosen_col = candidate_columns[rng.choice(len(candidate_columns))]
chosen_val_idx = rng.randint(0, len(hist[chosen_col[0]][chosen_col[1]]))
chosen_val = hist[chosen_col[0]][chosen_col[1]][chosen_val_idx]

rand_id = f"c{chosen_col[2]}_v{chosen_val_idx}"
view_name = sql_genetion_utils.get_view_name(
"main_given_predicate", [schema["dataset"], schema["use_cols"], rand_id]
"main_given_predicate",
[schema["dataset"], schema["use_cols"], rand_id],
)
else:
view_name = sql_genetion_utils.get_view_name("main", [schema["dataset"], schema["use_cols"]])
view_name = sql_genetion_utils.get_view_name(
"main", [schema["dataset"], schema["use_cols"]]
)

select_columns = []
# <table name>___<column name>
Expand Down Expand Up @@ -142,21 +152,37 @@ def create_full_outer_view(args, rng, data_manager, schema, hist=None, column_in
else:
pred_val_st = chosen_val[0]
pred_val_ta = chosen_val[0] + chosen_val[2]
predicate = f"{pred_col_name} >= {pred_val_st} AND {pred_col_name} < {pred_val_ta}"
predicate = (
f"{pred_col_name} >= {pred_val_st} AND {pred_col_name} < {pred_val_ta}"
)

full_outer_join_sql += f""" WHERE {predicate}; """
else:
full_outer_join_sql += ";"

args.logger.info(f"View SQL: {full_outer_join_sql}")

data_manager.create_view(args.logger, view_name, full_outer_join_sql, type="materialized", drop_if_exists=False)
data_manager.create_view(
args.logger,
view_name,
full_outer_join_sql,
type="materialized",
drop_if_exists=False,
)
args.logger.info("... finished: generating full outer joined view...")

return view_name


def run_generator(data_manager, schema, column_info, args, rng, check_execution_result=True, log_step=1):
def run_generator(
data_manager,
schema,
column_info,
args,
rng,
check_execution_result=False,
log_step=1,
):
all_table_set = set(schema["join_tables"])
join_clause_list = schema["join_clauses"]
join_keys = schema["join_keys"]
Expand All @@ -174,9 +200,13 @@ def run_generator(data_manager, schema, column_info, args, rng, check_execution_
objs = list()

dtype_dict = CaseInsensitiveDict(column_info["dtype_dict"])
args.IDS, args.HASH_CODES, args.NOTES, args.CATEGORIES, args.FOREIGN_KEYS = sql_genetion_utils.set_col_info(
column_info
)
(
args.IDS,
args.HASH_CODES,
args.NOTES,
args.CATEGORIES,
args.FOREIGN_KEYS,
) = sql_genetion_utils.set_col_info(column_info)

# for n in range(1,num_queries+1):
pbar = tqdm(total=args.num_queries)
Expand All @@ -190,7 +220,9 @@ def run_generator(data_manager, schema, column_info, args, rng, check_execution_
with open(inner_query_path, "r") as fp:
q_count = len(fp.readlines())
inner_query_objs += utils.load_objs(inner_query_path + ".obj", q_count)
inner_query_graphs += utils.load_graphs(inner_query_path + ".graph", q_count)
inner_query_graphs += utils.load_graphs(
inner_query_path + ".graph", q_count
)
# .obj type files
else:
inner_query_objs = None
Expand Down Expand Up @@ -314,7 +346,9 @@ def run_generator(data_manager, schema, column_info, args, rng, check_execution_
t_args.__dict__.update(json.load(f))
args = parser.parse_args(namespace=t_args)

formatter = logging.Formatter("[[ %(levelname)s ]]::%(asctime)s::%(funcName)s::%(lineno)d - %(message)s")
formatter = logging.Formatter(
"[[ %(levelname)s ]]::%(asctime)s::%(funcName)s::%(lineno)d - %(message)s"
)

consol_handler = logging.StreamHandler(sys.stdout)
consol_handler.setLevel(logging.DEBUG)
Expand All @@ -339,7 +373,9 @@ def run_generator(data_manager, schema, column_info, args, rng, check_execution_
DBUserID = config["DB"]["data"]["UserID"]
DBUserPW = config["DB"]["data"]["UserPW"]

data_manager, table_info = utils.connect_data_manager(IP, port, DBUserID, DBUserPW, schema)
data_manager, table_info = utils.connect_data_manager(
IP, port, DBUserID, DBUserPW, schema
)
args.table_info = CaseInsensitiveDict(table_info)

COL_INFO = json.load(open(f"{args.data_dir}/{args.db}/dtype_dict.json"))
Expand All @@ -350,6 +386,8 @@ def run_generator(data_manager, schema, column_info, args, rng, check_execution_
args.fo_view_name = fo_view_name
else:
hist = json.load(open(f"{args.data_dir}/{args.db}/selective_histogram.json"))
args.fo_view_name = create_full_outer_view(args, rng, data_manager, schema, hist=hist, column_info=column_info)
args.fo_view_name = create_full_outer_view(
args, rng, data_manager, schema, hist=hist, column_info=column_info
)

run_generator(data_manager, schema, column_info, args, rng)

0 comments on commit 59b8003

Please sign in to comment.