-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
208 lines (180 loc) · 6.95 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
from packages.exceptions import SqlException
from packages import debug
import sqlparse
import classes
import itertools
# Module-level database handle: select_query looks tables up in it and
# main() fills it via db.load_contents().
db = classes.Database(tables=[], name="Himalayan Database")
def check_overlapping_fields(columns, key):
    """
    Qualify a bare column name against the columns of the joined tables.

    Every entry of ``columns`` is split on ``.`` and its second component is
    compared with ``key`` (an entry without a ``.`` raises IndexError).
    If exactly one entry matches, that entry (e.g. ``"table.col"``) is
    returned; if none match, ``key`` is returned unchanged.

    :raises SqlException: as soon as a second matching entry is seen, i.e.
        the bare name is ambiguous across the joined tables.
    """
    resolved = None
    for candidate in columns:
        if candidate.split(".")[1] != key:
            continue
        if resolved is not None:
            raise SqlException("EM: Joined tables have one of the fields overlapping. So, you need to specify in <table_name>.<column_name> format")
        resolved = candidate
    return key if resolved is None else resolved
def where_helper(temp_table, all_columns, where):
    """
    Apply only the first comparison of a WHERE clause to ``temp_table``.

    Only ``where.tokens[2]`` (the comparison right after ``WHERE``) is used,
    which covers a plain ``WHERE a=8`` and the first half of an ``AND``.
    When the right-hand side is an integer literal the table's
    ``delete_rows_by_int`` is called; otherwise the right-hand side is
    treated as a column name and ``delete_rows_by_col`` is called. Bare
    column names on either side are qualified with
    ``check_overlapping_fields``.

    :param temp_table: table to filter; its delete_rows_* methods are called
    :param all_columns: ``table.column`` names used to qualify bare names
    :param where: the sqlparse WHERE group
    :return: ``temp_table``
    :raises SqlException: ambiguous column name (message kept from
        ``check_overlapping_fields``), or "Invalid Syntax" for any other failure
    """
    try:
        comparison = where.tokens[2]  # comparison = "A=8"
        # Drop whitespace so the tokens are exactly [lhs, operator, rhs].
        # NOTE(review): is_whitespace() is a method in sqlparse 0.1.x; newer
        # sqlparse exposes it as an attribute -- confirm the pinned version.
        comparison.tokens = [x for x in comparison.tokens if not x.is_whitespace()]
        key = str(comparison.tokens[0])  # key = "A"
        if '.' not in key:
            key = check_overlapping_fields(all_columns, key)
        operator = str(comparison.tokens[1])
        rhs = str(comparison.tokens[2])
        # Only the int() conversion is guarded: a failure inside
        # delete_rows_by_int must not silently turn into a column comparison.
        try:
            value = int(rhs)
        except ValueError:
            # RHS is not an integer literal, so it names a column.
            if '.' not in rhs:
                rhs = check_overlapping_fields(all_columns, rhs)
            temp_table.delete_rows_by_col(key, rhs, operator)
        else:
            temp_table.delete_rows_by_int(key, value, operator)
    except SqlException:
        # Keep specific messages (e.g. ambiguous column) instead of masking them.
        raise
    except Exception:
        raise SqlException("Invalid Syntax")
    return temp_table
def _parse_comparison(comparison, all_columns):
    """
    Split a comparison group into ``(key, operator, value, is_int)``.

    Whitespace tokens are removed from ``comparison.tokens`` in place. Bare
    column names on either side are qualified with
    ``check_overlapping_fields``. ``value`` is an integer when the right-hand
    side is an integer literal (``is_int`` True), otherwise the column-name
    string (``is_int`` False).
    """
    # NOTE(review): is_whitespace() is a method in sqlparse 0.1.x; newer
    # sqlparse exposes it as an attribute -- confirm the pinned version.
    comparison.tokens = [x for x in comparison.tokens if not x.is_whitespace()]
    key = str(comparison.tokens[0])
    if '.' not in key:
        key = check_overlapping_fields(all_columns, key)
    operator = str(comparison.tokens[1])
    rhs = str(comparison.tokens[2])
    try:
        value = int(rhs)
    except ValueError:
        # Not an integer literal: the right-hand side names a column.
        if '.' not in rhs:
            rhs = check_overlapping_fields(all_columns, rhs)
        return key, operator, rhs, False
    return key, operator, value, True


def where_select_query(temp_table, all_columns, where):
    """
    Filter ``temp_table`` by a WHERE clause with at most one AND/OR.

    Expected ``where.tokens`` layouts (whitespace tokens included):
      * at most 5 tokens, ``WHERE <cmp>``: delegated to ``where_helper``;
      * at least 7 tokens, ``WHERE <cmp> AND|OR <cmp>``: connector at index 4,
        comparisons at indices 2 and 6.
    AND applies both comparisons one after the other. OR is handed to the
    table's ``delete_rows_by_both_ints`` / ``delete_rows_by_both_cols`` and
    needs both right-hand sides to be ints or both to be columns.

    :param temp_table: table to filter
    :param all_columns: ``table.column`` names used to qualify bare names
    :param where: the sqlparse WHERE group
    :return: the filtered table
    :raises SqlException: unknown connector, exactly 6 tokens, mixed OR
        operand kinds, or an ambiguous column name
    """
    tokens = where.tokens
    if len(tokens) >= 7:  # AND or OR are present
        connector = str(tokens[4])
        if connector == "AND":
            temp_table = where_helper(temp_table, all_columns, where)
            key, operator, value, is_int = _parse_comparison(tokens[6], all_columns)
            # Dispatch on the parsed kind; a failure inside the delete call
            # propagates instead of being retried as a column comparison.
            if is_int:
                temp_table.delete_rows_by_int(key, value, operator)
            else:
                temp_table.delete_rows_by_col(key, value, operator)
        elif connector == "OR":
            key1, op1, value1, is_int1 = _parse_comparison(tokens[2], all_columns)
            key2, op2, value2, is_int2 = _parse_comparison(tokens[6], all_columns)
            if is_int1 and is_int2:
                temp_table.delete_rows_by_both_ints(key1, value1, op1, key2, value2, op2)
            elif not is_int1 and not is_int2:
                temp_table.delete_rows_by_both_cols(key1, value1, op1, key2, value2, op2)
            else:
                raise SqlException("Only OR on joins with either comparisons with int or columns in both conditions supported.")
        else:
            raise SqlException("Invalid where condition")
    elif len(tokens) <= 5:  # Only where is present
        temp_table = where_helper(temp_table, all_columns, where)
    else:
        raise SqlException("Invalid where syntax")
    return temp_table
def _cross_join(table_list):
    """
    Build the temporary table holding the cross product of ``table_list``.

    Returns ``(temp_table, all_columns)``. ``all_columns`` is the
    concatenation of each table's ``prefix_table_name_to_columns()`` in
    ``table_list`` order. Every temp row is a fresh list, even for a single
    table, so temp rows never share a row object with the source tables.
    """
    tables = [db.get_table(name) for name in table_list]
    # Upper bound of the new table's columns: all columns of all tables.
    all_columns = list(itertools.chain.from_iterable(
        table.prefix_table_name_to_columns() for table in tables))
    temp_table = classes.Table(
        name="temp",
        columns=all_columns,
        rows=[]
    )
    # product() yields one tuple of rows (one row per table) at a time, so
    # the full cross product is never held as one big list here.
    for combination in itertools.product(*[table.get_rows() for table in tables]):
        temp_table.add_row(list(itertools.chain.from_iterable(combination)))
    return temp_table, all_columns


def _print_columns(temp_table, column_list):
    """
    Print the requested columns of ``temp_table`` as tab-separated text.

    A header line with the column names comes first, then one line per row.
    Every cell is followed by a tab.

    :raises SqlException: if the requested columns have different lengths
    """
    columns = [temp_table.get_col(col) for col in column_list]
    expected = len(columns[0])
    if any(len(values) != expected for values in columns):
        raise SqlException("Incompatible column lengths: Generally, this happens when there's an aggregated query wihtout GROUP BY having non-aggregated column")
    print("".join(col + "\t" for col in column_list))
    for row in zip(*columns):
        print("".join(str(value) + "\t" for value in row))


def select_query(stmt):
    """
    Execute a parsed SELECT statement and print its result.

    ``stmt`` is the token list of one sqlparse statement, whitespace tokens
    included, laid out as ``SELECT <columns> FROM <tables> [WHERE ...]``:
    - columns at index 2, ``FROM`` at 4, tables at 6;
    - the WHERE group (if any) at 8.

    The cross product of the listed tables is built in a temporary table and
    filtered by the WHERE clause. It is then printed whole (``*``) or
    restricted to the requested columns.

    :raises SqlException: malformed statement, repeated table, bad WHERE
        clause, or requested columns of differing lengths
    """
    try:
        column_list = [x.strip() for x in str(stmt[2]).split(",")]
        from_keyword = str(stmt[4])
        table_list = [x.strip() for x in str(stmt[6]).split(",")]
    except (IndexError, TypeError):
        raise SqlException("Invalid Syntax")
    if from_keyword != "FROM":
        # Without FROM at index 4, index 6 is not a table list.
        raise SqlException("Invalid Syntax")
    if len(table_list) != len(set(table_list)):
        raise SqlException("Not Unique Tables")
    temp_table, all_columns = _cross_join(table_list)
    if len(stmt) >= 9:  # 'where' is present
        where = stmt[8]  # where = "WHERE A=8"
        # A plain token (e.g. stray punctuation) has no .tokens; reject it cleanly.
        if not getattr(where, "tokens", None) or str(where.tokens[0]) != "WHERE":
            raise SqlException("Invalid Syntax")
        temp_table = where_select_query(temp_table, all_columns, where)
    if '*' in column_list:
        temp_table.print_contents()
    else:
        _print_columns(temp_table, column_list)
def main():
    """
    Run the interactive ``hdsql>>`` prompt.

    Loads and prints the database, then reads lines until ``QUIT`` (any case,
    surrounding whitespace ignored) or end of input.

    Each line is split into statements with ``sqlparse.split`` and only
    ``SELECT`` is executed:
    - an ``SqlException`` from a SELECT is printed, and the next statement runs;
    - any other failure aborts the rest of the line with a generic message.
    """
    try:
        read_line = raw_input  # Python 2
    except NameError:  # Python 3 renamed raw_input to input
        read_line = input
    db.load_contents()
    db.print_contents()
    while True:
        try:
            query = read_line("hdsql>> ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: exit cleanly instead of a traceback.
            print("")
            print("Bye")
            break
        if query.strip().upper() == "QUIT":
            print("Bye")
            break
        try:
            for command in sqlparse.split(query):
                if not command.strip():
                    continue  # blank command: nothing to parse
                parsed = sqlparse.parse(sqlparse.format(command, keyword_case='upper'))
                if not parsed or not parsed[0].tokens:
                    continue  # nothing parseable; avoids an uncaught IndexError
                stmt = parsed[0].tokens
                qtype = str(stmt[0])
                if len(stmt) < 7:
                    raise SqlException("Invalid Syntax")
                if qtype == "SELECT":
                    try:
                        select_query(stmt)
                    except SqlException as e:
                        print(getattr(e, "message", str(e)))
                    except Exception:
                        raise SqlException("Syntax Error/Query Execution Error")
                else:
                    raise SqlException(qtype + " not supported.")
                print("")
        except SqlException as e:
            print(getattr(e, "message", str(e)))
if __name__ == "__main__":
    # Start the interactive prompt only when run as a script, not on import.
    main()