generated from ulavalIFTGLOateliers/GLO2005-Migration
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatabase.py
110 lines (79 loc) · 3.53 KB
/
database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import pymysql
from dotenv import load_dotenv
from sql_utils import run_sql_file
class Database:
    """Wrapper around a PyMySQL connection that applies numbered SQL migrations.

    Migration scripts are read from ``db_scripts/`` as ``migrate_<n>.sql``,
    with a matching ``rollback_<n>.sql`` per migration. ``migration_counter``
    tracks, in memory only, how many migrations this instance has pushed.

    The instance can be used as a context manager so the connection is closed
    automatically::

        with Database() as db:
            db.up()
    """

    def __init__(self):
        """Load connection settings from the ``.env`` file and connect.

        Reads HOST, PORT (default 3306), DATABASE, USER and PASSWORD from the
        environment after loading ``.env``.
        """
        # override=True: by default load_dotenv does not replace variables that
        # already exist in the process environment. On most Unix shells USER is
        # already set to the login name, which would silently shadow the USER
        # entry of .env and connect as the wrong MySQL user.
        load_dotenv(override=True)
        self.host = os.environ.get("HOST")
        self.port = int(os.environ.get("PORT", 3306))
        self.database = os.environ.get("DATABASE")
        self.user = os.environ.get("USER")
        self.password = os.environ.get("PASSWORD")
        self._open_sql_connection()
        self.migration_counter = 0

    def _open_sql_connection(self):
        """Open an autocommit connection and keep one shared cursor on it."""
        self.connection = pymysql.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            db=self.database,
            autocommit=True,
        )
        self.cursor = self.connection.cursor()

    def close(self):
        """Close the cursor and the connection.

        Safe to call more than once, or on an instance whose connection was
        never opened. Afterwards ``cursor`` and ``connection`` are ``None``.
        """
        cursor = getattr(self, "cursor", None)
        if cursor is not None:
            cursor.close()
        self.cursor = None
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.close()
        self.connection = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions raised inside the with-block

    @staticmethod
    def _quote_identifier(name):
        """Return *name* as a backtick-quoted MySQL identifier.

        Embedded backticks are doubled, the MySQL escape for identifiers, so
        reserved words (e.g. ``order``) and unusual names are handled safely.
        """
        return "`" + str(name).replace("`", "``") + "`"

    def push_migration(self):
        """Run the next migration script, ``migrate_<counter + 1>.sql``.

        The counter is only advanced after ``run_sql_file`` returns without
        raising, so a failed migration leaves the counter unchanged.
        """
        migration_to_push = self.migration_counter + 1
        migration_file = f"db_scripts/migrate_{migration_to_push}.sql"
        # accept_empty=False presumably makes run_sql_file reject an empty
        # script — TODO confirm against sql_utils.
        run_sql_file(self.cursor, migration_file, accept_empty=False)
        self.migration_counter = migration_to_push

    def rollback(self):
        """Undo the most recent migration via ``rollback_<counter>.sql``.

        Raises:
            ValueError: if no migration has been pushed on this instance.
        """
        if self.migration_counter < 1:
            raise ValueError("There are no rollbacks in the rollback stack.")
        rollback_file = f"db_scripts/rollback_{self.migration_counter}.sql"
        run_sql_file(self.cursor, rollback_file)
        self.migration_counter -= 1

    def up(self):
        """Run ``drop.sql`` (resetting the counter), then ``up.sql``."""
        self.drop()
        run_sql_file(self.cursor, "db_scripts/up.sql")

    def drop(self):
        """Run ``drop.sql`` and reset the migration counter to 0."""
        run_sql_file(self.cursor, "db_scripts/drop.sql")
        self.migration_counter = 0

    def get_table_names(self):
        """Return the names of the base tables (not views) in the database."""
        req = (
            "SELECT table_name FROM INFORMATION_SCHEMA.TABLES "
            "WHERE table_type = 'BASE TABLE' AND table_schema = %s;"
        )
        self.cursor.execute(req, (self.database,))
        return [row[0] for row in self.cursor.fetchall()]

    def get_table_column_names(self, table):
        """Return the column names of *table*, in declaration order."""
        req = (
            "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS "
            "WHERE TABLE_NAME = %s AND TABLE_SCHEMA = %s "
            "ORDER BY ORDINAL_POSITION;"
        )
        self.cursor.execute(req, (table, self.database))
        return [row[0] for row in self.cursor.fetchall()]

    def get_table_data(self, table):
        """Return every row of *table*, each row as a list."""
        # Identifiers cannot be bound as query parameters, so quote instead.
        req = f"SELECT * FROM {self._quote_identifier(table)};"
        self.cursor.execute(req)
        return [list(row) for row in self.cursor.fetchall()]

    def get_table_primary_key(self, table):
        """Return the first primary-key column of *table* as a 1-tuple.

        Returns ``None`` when the table has no primary key. For a composite
        key only its first column (by position in the key) is returned.
        """
        req = (
            "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
            "WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s "
            "AND CONSTRAINT_NAME = 'PRIMARY' "
            "ORDER BY ORDINAL_POSITION;"
        )
        self.cursor.execute(req, (self.database, table))
        return self.cursor.fetchone()

    def get_table_foreign_keys(self, table):
        """Return the foreign-key columns of *table*.

        Each entry is a dict with keys ``column_name``,
        ``referenced_table_name`` and ``referenced_column_name``.
        """
        # Filter on REFERENCED_TABLE_NAME rather than "not PRIMARY": the
        # KEY_COLUMN_USAGE view also lists UNIQUE-constraint columns, which
        # have no referenced table and are not foreign keys.
        req = (
            "SELECT COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME "
            "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
            "WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s "
            "AND REFERENCED_TABLE_NAME IS NOT NULL;"
        )
        self.cursor.execute(req, (self.database, table))
        return [
            {
                "column_name": column,
                "referenced_table_name": ref_table,
                "referenced_column_name": ref_column,
            }
            for column, ref_table, ref_column in self.cursor.fetchall()
        ]

    def get_cursor(self):
        """Return the shared cursor."""
        return self.cursor

    def get_connection(self):
        """Return the underlying PyMySQL connection."""
        return self.connection

    def get_migration_stack_size(self):
        """Return how many migrations are currently pushed."""
        return self.migration_counter