Skip to content

Commit

Permalink
Add DSC replicates by default #95
Browse files Browse the repository at this point in the history
  • Loading branch information
gaow committed Mar 13, 2018
1 parent b6e8094 commit eb34e7a
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions src/query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def find_partial_index(xx, ordering):
for ii, i in enumerate(ordering):
if xx.startswith(i):
return ii
if xx.split('.')[1] == 'DSC_REPLICATE':
return -1
raise ValueError(f'{xx} not in list {ordering}')

class Query_Processor:
Expand All @@ -83,7 +85,7 @@ def __init__(self, db, targets, condition = None, groups = None, add_path = Fals
self.pipelines = self.get_pipelines()
# 3. identify and extract which parts of each pipeline are involved
# based on tables in target / condition
self.pipelines = self.filter_pipelines()
self.pipelines, self.first_modules = self.filter_pipelines()
# 4. make select / from / where clause
select_clauses, select_fields = self.get_select_clause()
from_clauses = self.get_from_clause()
Expand Down Expand Up @@ -125,6 +127,8 @@ def check_table_field(self, value, check_field = 0):
raise DBError(f"Cannot find module ``{x}`` in DSC results ``{self.db}``.")
k = list(self.data.keys())[keys_lower.index(x.lower())]
y_low = y.lower()
if y_low == 'dsc_replicate':
raise DBError(f'Cannot query on ``DSC_REPLICATE`` in module ``{k}``')
if y_low in [i.lower() for i in self.data[k]] and y_low in [i.lower() for i in self.data['.output'][k]] and check_field == 1:
self.field_warnings[k] = f"Variable ``{y}`` is both parameter and output in module ``{k}``. Parameter variable ``{y}`` is extracted. To obtain output variable ``{y}`` please use ``{k}.output.{y}`` to specify the query target."
if not y_low in [i.lower() for i in self.data[k]] and check_field == 2:
Expand Down Expand Up @@ -207,12 +211,17 @@ def filter_pipelines(self):
for each pipeline extract the sub pipeline that the query involves
'''
res = []
heads = []
tables = uniq_list([x[0].lower() for x in self.target_tables] + [x[0].lower() for x in self.condition_tables])
for pipeline in self.pipelines:
pidx = [l[0] for l in enumerate(pipeline) if l[1] in tables]
# The first module contains the replicate info and therefore has to show up
if pidx[0] != 0:
pidx = [0] + pidx
if len(pidx) and not pipeline[pidx[0]:pidx[-1]+1] in res:
res.append(pipeline[pidx[0]:pidx[-1]+1])
return filter_sublist(res)
heads.append(pipeline[0])
return filter_sublist(res), heads

def get_from_clause(self):
res = []
Expand All @@ -221,13 +230,14 @@ def get_from_clause(self):
res.append(('FROM {0} '.format(pipeline[0]) + ' '.join(["INNER JOIN {1} ON {0}.__parent__ = {1}.__id__".format(pipeline[i], pipeline[i+1]) for i in range(len(pipeline) - 1)])).strip())
return res

def get_one_select_clause(self, pipeline):
def get_one_select_clause(self, pipeline, first_module):
clause = []
fields = []
for item in self.target_tables:
# a table in targets does not exist in this pipeline
if item[0] not in pipeline:
continue
# remove tables in targets that do not exist in this pipeline
tables = [item for item in self.target_tables if item[0] in pipeline]
if len(tables):
tables = [(first_module, 'DSC_REPLICATE')] + tables
for item in tables:
fields.append('.'.join(item) if item[1] else item[0])
if item[1] is None:
clause.append("'{0}' AS {0}".format(item[0]))
Expand All @@ -248,10 +258,11 @@ def get_select_clause(self):
select = []
select_fields = []
new_pipelines = []
for pipeline in self.pipelines:
clause, fields = self.get_one_select_clause(pipeline)
for pipeline, first_module in zip(self.pipelines, self.first_modules):
clause, fields = self.get_one_select_clause(pipeline, first_module)
# Caution: should have the same length as input target
if len(fields) != len(self.targets):
# plus 1 because of "DSC_REPLICATE"
if len(fields) != len(self.targets) + 1:
continue
new_pipelines.append(pipeline)
select.append(clause)
Expand Down Expand Up @@ -368,6 +379,13 @@ def merge_tables(self):
table = table.rename(columns = {g: f'{g}:id' for g in self.groups})
table = table[sorted(table.columns, key = lambda x: (find_partial_index(x, targets), not x.endswith(':id')))]
table = table.rename(columns = {f'{g}:id': g for g in self.groups})
# Finally deal with the `DSC_REPLICATE` column
rep_cols = [x for x in table.columns if x.endswith('.DSC_REPLICATE')]
table.insert(0, 'DSC', table.loc[:, rep_cols].apply(lambda x: x.dropna().tolist(), 1))
if not all(table['DSC'].apply(len) == 1):
raise DBError(f'(Possible bug) DSC replicates cannot be merged due to collating entries.')
table['DSC'] = table['DSC'].apply(lambda x: int(x[0]))
table.drop(columns = rep_cols, inplace = True)
return table

def get_queries(self):
Expand Down

0 comments on commit eb34e7a

Please sign in to comment.