diff --git a/app/db/init_db.py b/app/db/init_db.py index 2a83093..9148262 100644 --- a/app/db/init_db.py +++ b/app/db/init_db.py @@ -303,6 +303,7 @@ def initialize_database(): "table_annotation_id": "VARCHAR(64) NOT NULL", "constraint_type": "VARCHAR(16) NOT NULL", "name": "VARCHAR(255)", + "description": "TEXT", "expression": "TEXT", "ref_table": "VARCHAR(255)", "on_update_action": "VARCHAR(16)", diff --git a/app/repository/annotation_repository.py b/app/repository/annotation_repository.py index c32ec61..0a74be9 100644 --- a/app/repository/annotation_repository.py +++ b/app/repository/annotation_repository.py @@ -20,6 +20,11 @@ class AnnotationRepository: + """ + 어노테이션 데이터에 대한 데이터베이스 CRUD 작업을 처리합니다. + 모든 메서드는 내부적으로 `sqlite3`를 사용하여 로컬 DB와 상호작용합니다. + """ + def create_full_annotation( self, db_conn: sqlite3.Connection, @@ -96,6 +101,7 @@ def create_full_annotation( c.table_annotation_id, c.constraint_type, c.name, + c.description, c.expression, c.ref_table, c.on_update_action, @@ -107,8 +113,8 @@ def create_full_annotation( ] cursor.executemany( """ - INSERT INTO table_constraint (id, table_annotation_id, constraint_type, name, expression, ref_table, on_update_action, on_delete_action, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO table_constraint (id, table_annotation_id, constraint_type, name, description, expression, ref_table, on_update_action, on_delete_action, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, constraint_data, ) @@ -182,18 +188,24 @@ def find_full_annotation_by_id(self, annotation_id: str) -> FullAnnotationRespon # 컬럼 정보 cursor.execute( - "SELECT id, column_name, description FROM column_annotation WHERE table_annotation_id = ?", + "SELECT id, column_name, description, data_type, is_nullable, default_value FROM column_annotation WHERE table_annotation_id = ?", (table_id,), ) - columns = [ColumnAnnotationDetail.model_validate(dict(c)) for c in cursor.fetchall()] + columns = [] + for c in cursor.fetchall(): + c_dict = dict(c) + c_dict["is_nullable"] = ( + bool(c_dict["is_nullable"]) if c_dict.get("is_nullable") is not None else None + ) + columns.append(ColumnAnnotationDetail.model_validate(c_dict)) # 제약조건 정보 cursor.execute( """ - SELECT tc.name, tc.constraint_type, ca.column_name + SELECT tc.name, tc.constraint_type, tc.description, ca.column_name FROM table_constraint tc - JOIN constraint_column cc ON tc.id = cc.constraint_id - JOIN column_annotation ca ON cc.column_annotation_id = ca.id + LEFT JOIN constraint_column cc ON tc.id = cc.constraint_id + LEFT JOIN column_annotation ca ON cc.column_annotation_id = ca.id WHERE tc.table_annotation_id = ? """, (table_id,), @@ -201,10 +213,16 @@ def find_full_annotation_by_id(self, annotation_id: str) -> FullAnnotationRespon constraint_map = {} for row in cursor.fetchall(): if row["name"] not in constraint_map: - constraint_map[row["name"]] = {"type": row["constraint_type"], "columns": []} - constraint_map[row["name"]]["columns"].append(row["column_name"]) + constraint_map[row["name"]] = { + "type": row["constraint_type"], + "columns": [], + "description": row["description"], + } + if row["column_name"]: + constraint_map[row["name"]]["columns"].append(row["column_name"]) constraints = [ - ConstraintDetail(name=k, type=v["type"], columns=v["columns"]) for k, v in constraint_map.items() + ConstraintDetail(name=k, type=v["type"], columns=v["columns"], description=v["description"]) + for k, v in constraint_map.items() ] # 인덱스 정보 diff --git a/app/repository/user_db_repository.py b/app/repository/user_db_repository.py index b437572..0ddf8b3 100644 --- a/app/repository/user_db_repository.py +++ b/app/repository/user_db_repository.py @@ -3,6 +3,7 @@ import oracledb +from app.core.enum.db_driver import DBTypesEnum from app.core.exceptions import APIException from app.core.status import CommonCode from app.core.utils import get_db_path @@ -209,50 +210,12 @@ def find_columns( connection = self._connect(driver_module, **kwargs) cursor = connection.cursor() - if db_type == "sqlite": - # SQLite는 PRAGMA를 직접 실행 - pragma_sql = f"PRAGMA table_info('{table_name}')" - cursor.execute(pragma_sql) - columns_raw = cursor.fetchall() - columns = [ - ColumnInfo( - name=c[1], - type=c[2], - nullable=(c[3] == 0), # notnull == 0 means nullable - default=c[4], - comment=None, - is_pk=(c[5] == 1), - ) - for c in columns_raw - ] + if db_type == DBTypesEnum.sqlite.name: + columns = self._find_columns_for_sqlite(cursor, table_name) + elif db_type == DBTypesEnum.postgresql.name: + columns = self._find_columns_for_postgresql(cursor, schema_name, table_name) else: - if "%s" in column_query or "?" in column_query: - cursor.execute(column_query, (schema_name, table_name)) - elif ":owner" in column_query and ":table" in column_query: - owner_bind = schema_name.upper() if schema_name else schema_name - table_bind = table_name.upper() if table_name else table_name - try: - cursor.execute(column_query, {"owner": owner_bind, "table": table_bind}) - except Exception: - try: - pos_query = column_query.replace(":owner", ":1").replace(":table", ":2") - cursor.execute(pos_query, [owner_bind, table_bind]) - except Exception as e: - raise APIException(CommonCode.FAIL) from e - else: - cursor.execute(column_query) - - columns = [ - ColumnInfo( - name=c[0], - type=c[1], - nullable=(c[2] in ["YES", "Y", True]), - default=c[3], - comment=c[4] if len(c) > 4 else None, - is_pk=(c[5] in ["PRI", True] if len(c) > 5 else False), - ) - for c in cursor.fetchall() - ] + columns = self._find_columns_for_general(cursor, column_query, schema_name, table_name) return ColumnListResult(is_successful=True, code=CommonCode.SUCCESS_FIND_COLUMNS, columns=columns) except Exception: @@ -261,94 +224,310 @@ def find_columns( if connection: connection.close() + def _find_columns_for_sqlite(self, cursor: Any, table_name: str) -> list[ColumnInfo]: + pragma_sql = f"PRAGMA table_info('{table_name}')" + cursor.execute(pragma_sql) + columns_raw = cursor.fetchall() + # SQLite는 pragma에서 순서(cid)를 반환하지만, ordinal_position은 1부터 시작하는 표준이므로 +1 + return [ + ColumnInfo( + name=c[1], + type=c[2], + nullable=(c[3] == 0), + default=c[4], + comment=None, + is_pk=(c[5] == 1), + ordinal_position=c[0] + 1, + ) + for c in columns_raw + ] + + def _find_columns_for_postgresql(self, cursor: Any, schema_name: str, table_name: str) -> list[ColumnInfo]: + sql = """ + SELECT + c.column_name, + c.udt_name, + c.is_nullable, + c.column_default, + c.ordinal_position, + (SELECT pg_catalog.col_description(cls.oid, c.dtd_identifier::int) + FROM pg_catalog.pg_class cls + JOIN pg_catalog.pg_namespace n ON n.oid = cls.relnamespace + WHERE cls.relname = c.table_name AND n.nspname = c.table_schema) as comment, + CASE WHEN kcu.column_name IS NOT NULL THEN TRUE ELSE FALSE END as is_pk + FROM + information_schema.columns c + LEFT JOIN information_schema.key_column_usage kcu + ON c.table_schema = kcu.table_schema + AND c.table_name = kcu.table_name + AND c.column_name = kcu.column_name + AND kcu.constraint_name IN ( + SELECT constraint_name + FROM information_schema.table_constraints + WHERE table_schema = %s + AND table_name = %s + AND constraint_type = 'PRIMARY KEY' + ) + WHERE + c.table_schema = %s AND c.table_name = %s + ORDER BY + c.ordinal_position; + """ + cursor.execute(sql, (schema_name, table_name, schema_name, table_name)) + columns_raw = cursor.fetchall() + return [ + ColumnInfo( + name=c[0], + type=c[1], + nullable=(c[2] == "YES"), + default=c[3], + ordinal_position=c[4], + comment=c[5], + is_pk=c[6], + ) + for c in columns_raw + ] + + def _find_columns_for_general( + self, cursor: Any, column_query: str, schema_name: str, table_name: str + ) -> list[ColumnInfo]: + if "%s" in column_query or "?" in column_query: + cursor.execute(column_query, (schema_name, table_name)) + elif ":owner" in column_query and ":table" in column_query: + owner_bind = schema_name.upper() if schema_name else schema_name + table_bind = table_name.upper() if table_name else table_name + try: + cursor.execute(column_query, {"owner": owner_bind, "table": table_bind}) + except Exception: + try: + pos_query = column_query.replace(":owner", ":1").replace(":table", ":2") + cursor.execute(pos_query, [owner_bind, table_bind]) + except Exception as e: + raise APIException(CommonCode.FAIL) from e + else: + cursor.execute(column_query) + + columns = [] + for c in cursor.fetchall(): + data_type = c[1] + if c[6] is not None: + data_type = f"{data_type}({c[6]})" + elif c[7] is not None and c[8] is not None: + data_type = f"{data_type}({c[7]}, {c[8]})" + + columns.append( + ColumnInfo( + name=c[0], + type=data_type, + nullable=(c[2] in ["YES", "Y", True]), + default=c[3], + comment=c[4] if len(c) > 4 else None, + is_pk=(c[5] in ["PRI", True] if len(c) > 5 else False), + ) + ) + return columns + def find_constraints( - self, driver_module: Any, db_type: str, table_name: str, **kwargs: Any + self, driver_module: Any, db_type: str, schema_name: str, table_name: str, **kwargs: Any ) -> list[ConstraintInfo]: """ 테이블의 제약 조건 정보를 조회합니다. - - 현재는 SQLite만 지원합니다. + - 현재는 SQLite, PostgreSQL만 지원합니다. - 실패 시 DB 드라이버의 예외를 직접 발생시킵니다. """ connection = None try: connection = self._connect(driver_module, **kwargs) cursor = connection.cursor() - constraints = [] - - if db_type == "sqlite": - # Foreign Key 제약 조건 조회 - fk_list_sql = f"PRAGMA foreign_key_list('{table_name}')" - cursor.execute(fk_list_sql) - fks = cursor.fetchall() - - # Foreign Key 정보를 그룹화 - fk_groups = {} - for fk in fks: - fk_id = fk[0] - if fk_id not in fk_groups: - fk_groups[fk_id] = {"referenced_table": fk[2], "columns": [], "referenced_columns": []} - fk_groups[fk_id]["columns"].append(fk[3]) - fk_groups[fk_id]["referenced_columns"].append(fk[4]) - - for _, group in fk_groups.items(): - constraints.append( - ConstraintInfo( - name=f"fk_{table_name}_{'_'.join(group['columns'])}", - type="FOREIGN KEY", - columns=group["columns"], - referenced_table=group["referenced_table"], - referenced_columns=group["referenced_columns"], - ) - ) - - # 다른 DB 타입에 대한 제약 조건 조회 로직 추가 가능 - # elif db_type == "postgresql": ... - - return constraints + + if db_type == DBTypesEnum.sqlite.name: + return self._find_constraints_for_sqlite(cursor, table_name) + elif db_type == DBTypesEnum.postgresql.name: + return self._find_constraints_for_postgresql(cursor, schema_name, table_name) + # elif db_type == ...: + return [] finally: if connection: connection.close() - def find_indexes(self, driver_module: Any, db_type: str, table_name: str, **kwargs: Any) -> list[IndexInfo]: + def _find_constraints_for_sqlite(self, cursor: Any, table_name: str) -> list[ConstraintInfo]: + constraints = [] + fk_list_sql = f"PRAGMA foreign_key_list('{table_name}')" + cursor.execute(fk_list_sql) + fks = cursor.fetchall() + + # Foreign Key 정보를 그룹화 + fk_groups = {} + for fk in fks: + fk_id = fk[0] + if fk_id not in fk_groups: + fk_groups[fk_id] = {"referenced_table": fk[2], "columns": [], "referenced_columns": []} + fk_groups[fk_id]["columns"].append(fk[3]) + fk_groups[fk_id]["referenced_columns"].append(fk[4]) + + for _, group in fk_groups.items(): + constraints.append( + ConstraintInfo( + name=f"fk_{table_name}_{'_'.join(group['columns'])}", + type="FOREIGN KEY", + columns=group["columns"], + referenced_table=group["referenced_table"], + referenced_columns=group["referenced_columns"], + ) + ) + return constraints + + def _find_constraints_for_postgresql(self, cursor: Any, schema_name: str, table_name: str) -> list[ConstraintInfo]: + sql = """ + SELECT + tc.constraint_name, + tc.constraint_type, + kcu.column_name, + rc.update_rule, + rc.delete_rule, + ccu.table_name AS foreign_table_name, + ccu.column_name AS foreign_column_name, + chk.check_clause + FROM + information_schema.table_constraints tc + LEFT JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema AND tc.table_name = kcu.table_name + LEFT JOIN information_schema.referential_constraints rc + ON tc.constraint_name = rc.constraint_name AND tc.table_schema = rc.constraint_schema + LEFT JOIN information_schema.constraint_column_usage ccu + ON rc.unique_constraint_name = ccu.constraint_name AND rc.unique_constraint_schema = ccu.table_schema + LEFT JOIN information_schema.check_constraints chk + ON tc.constraint_name = chk.constraint_name AND tc.table_schema = chk.constraint_schema + WHERE + tc.table_schema = %s AND tc.table_name = %s; + """ + cursor.execute(sql, (schema_name, table_name)) + raw_constraints = cursor.fetchall() + + constraint_map = {} + for row in raw_constraints: + # Filter out 'NOT NULL' constraints which are handled by `is_nullable` in column info + const_type = row[1] + check_clause = row[7] + if const_type == "CHECK" and check_clause and "IS NOT NULL" in check_clause.upper(): + continue + + (name, _, column, on_update, on_delete, ref_table, ref_column, check_expr) = row + if name not in constraint_map: + constraint_map[name] = { + "type": const_type, + "columns": [], + "referenced_table": ref_table, + "referenced_columns": [], + "check_expression": check_expr, + "on_update": on_update, + "on_delete": on_delete, + } + if column and column not in constraint_map[name]["columns"]: + constraint_map[name]["columns"].append(column) + if ref_column and ref_column not in constraint_map[name]["referenced_columns"]: + constraint_map[name]["referenced_columns"].append(ref_column) + + return [ + ConstraintInfo( + name=name, + type=data["type"], + columns=data["columns"], + referenced_table=data["referenced_table"], + referenced_columns=data["referenced_columns"] if data["referenced_columns"] else None, + check_expression=data["check_expression"], + on_update=data["on_update"], + on_delete=data["on_delete"], + ) + for name, data in constraint_map.items() + ] + + def find_indexes( + self, driver_module: Any, db_type: str, schema_name: str, table_name: str, **kwargs: Any + ) -> list[IndexInfo]: """ 테이블의 인덱스 정보를 조회합니다. - - 현재는 SQLite만 지원합니다. - 실패 시 DB 드라이버의 예외를 직접 발생시킵니다. """ connection = None try: connection = self._connect(driver_module, **kwargs) cursor = connection.cursor() - indexes = [] - - if db_type == "sqlite": - index_list_sql = f"PRAGMA index_list('{table_name}')" - cursor.execute(index_list_sql) - raw_indexes = cursor.fetchall() - - for idx in raw_indexes: - index_name = idx[1] - is_unique = idx[2] == 1 - # "sqlite_autoindex_"로 시작하는 인덱스는 PK에 의해 자동 생성된 것이므로 제외 - if index_name.startswith("sqlite_autoindex_"): - continue - - index_info_sql = f"PRAGMA index_info('{index_name}')" - cursor.execute(index_info_sql) - index_columns = [row[2] for row in cursor.fetchall()] - - if index_columns: - indexes.append(IndexInfo(name=index_name, columns=index_columns, is_unique=is_unique)) - - # 다른 DB 타입에 대한 인덱스 조회 로직 추가 가능 - # elif db_type == "postgresql": ... - - return indexes + if db_type == DBTypesEnum.sqlite.name: + return self._find_indexes_for_sqlite(cursor, table_name) + elif db_type == DBTypesEnum.postgresql.name: + return self._find_indexes_for_postgresql(cursor, schema_name, table_name) + # elif db_type == ...: + return [] finally: if connection: connection.close() + def _find_indexes_for_sqlite(self, cursor: Any, table_name: str) -> list[IndexInfo]: + indexes = [] + index_list_sql = f"PRAGMA index_list('{table_name}')" + cursor.execute(index_list_sql) + raw_indexes = cursor.fetchall() + + for idx in raw_indexes: + index_name = idx[1] + is_unique = idx[2] == 1 + + # "sqlite_autoindex_"로 시작하는 인덱스는 PK에 의해 자동 생성된 것이므로 제외 + if index_name.startswith("sqlite_autoindex_"): + continue + + index_info_sql = f"PRAGMA index_info('{index_name}')" + cursor.execute(index_info_sql) + index_columns = [row[2] for row in cursor.fetchall()] + + if index_columns: + indexes.append(IndexInfo(name=index_name, columns=index_columns, is_unique=is_unique)) + return indexes + + def _find_indexes_for_postgresql(self, cursor: Any, schema_name: str, table_name: str) -> list[IndexInfo]: + sql = """ + SELECT + i.relname as index_name, + a.attname as column_name, + ix.indisunique as is_unique, + ix.indisprimary as is_primary + FROM + pg_class t, + pg_class i, + pg_index ix, + pg_attribute a, + pg_namespace n + WHERE + t.oid = ix.indrelid + and i.oid = ix.indexrelid + and a.attrelid = t.oid + and a.attnum = ANY(ix.indkey) + and t.relkind = 'r' + and n.oid = t.relnamespace + and n.nspname = %s + and t.relname = %s + ORDER BY + i.relname, a.attnum; + """ + cursor.execute(sql, (schema_name, table_name)) + raw_indexes = cursor.fetchall() + + index_map = {} + for row in raw_indexes: + index_name, column_name, is_unique, is_primary = row + if is_primary: # Exclude indexes created for PRIMARY KEY constraints + continue + if index_name not in index_map: + index_map[index_name] = {"columns": [], "is_unique": is_unique} + index_map[index_name]["columns"].append(column_name) + + return [ + IndexInfo(name=name, columns=data["columns"], is_unique=data["is_unique"]) + for name, data in index_map.items() + ] + # ───────────────────────────── # DB 연결 메서드 # ───────────────────────────── diff --git a/app/schemas/annotation/db_model.py b/app/schemas/annotation/db_model.py index 667765b..8a73aa9 100644 --- a/app/schemas/annotation/db_model.py +++ b/app/schemas/annotation/db_model.py @@ -39,6 +39,7 @@ class TableConstraintInDB(AnnotationBase): table_annotation_id: str constraint_type: ConstraintTypeEnum name: str | None = None + description: str | None = None expression: str | None = None ref_table: str | None = None on_update_action: str | None = None diff --git a/app/schemas/annotation/response_model.py b/app/schemas/annotation/response_model.py index 5602e0e..5ec4556 100644 --- a/app/schemas/annotation/response_model.py +++ b/app/schemas/annotation/response_model.py @@ -10,6 +10,9 @@ class ColumnAnnotationDetail(BaseModel): id: str column_name: str description: str | None = None + data_type: str | None = None + is_nullable: bool | None = None + default_value: str | None = None class ConstraintDetail(BaseModel): diff --git a/app/schemas/user_db/result_model.py b/app/schemas/user_db/result_model.py index 64d9190..d379019 100644 --- a/app/schemas/user_db/result_model.py +++ b/app/schemas/user_db/result_model.py @@ -53,6 +53,7 @@ class ColumnInfo(BaseModel): default: Any | None = Field(None, description="기본값") comment: str | None = Field(None, description="코멘트") is_pk: bool = Field(False, description="기본 키(Primary Key) 여부") + ordinal_position: int | None = Field(None, description="컬럼 순서") class ConstraintInfo(BaseModel): @@ -64,6 +65,8 @@ class ConstraintInfo(BaseModel): # FOREIGN KEY 관련 필드 referenced_table: str | None = Field(None, description="참조하는 테이블 (FK)") referenced_columns: list[str] | None = Field(None, description="참조하는 테이블의 컬럼 (FK)") + on_update: str | None = Field(None, description="UPDATE 시 동작 (FK)") + on_delete: str | None = Field(None, description="DELETE 시 동작 (FK)") # CHECK 관련 필드 check_expression: str | None = Field(None, description="CHECK 제약 조건 표현식") diff --git a/app/services/annotation_service.py b/app/services/annotation_service.py index b217586..c7a7ce7 100644 --- a/app/services/annotation_service.py +++ b/app/services/annotation_service.py @@ -36,6 +36,13 @@ class AnnotationService: def __init__( self, repository: AnnotationRepository = annotation_repository, user_db_serv: UserDbService = user_db_service ): + """ + AnnotationService를 초기화합니다. + + Args: + repository (AnnotationRepository): 어노테이션 레포지토리 의존성 주입. + user_db_serv (UserDbService): 사용자 DB 서비스 의존성 주입. + """ self.repository = repository self.user_db_service = user_db_serv @@ -65,7 +72,9 @@ async def create_annotation(self, request: AnnotationCreateRequest) -> FullAnnot conn = sqlite3.connect(str(db_path), timeout=10) conn.execute("BEGIN") - db_models = self._transform_ai_response_to_db_models(ai_response, db_profile, request.db_profile_id) + db_models = self._transform_ai_response_to_db_models( + ai_response, db_profile, request.db_profile_id, full_schema_info + ) self.repository.create_full_annotation(db_conn=conn, **db_models) conn.commit() @@ -82,21 +91,38 @@ async def create_annotation(self, request: AnnotationCreateRequest) -> FullAnnot return self.get_full_annotation(annotation_id) def _transform_ai_response_to_db_models( - self, ai_response: dict[str, Any], db_profile: AllDBProfileInfo, db_profile_id: str + self, + ai_response: dict[str, Any], + db_profile: AllDBProfileInfo, + db_profile_id: str, + full_schema_info: list[UserDBTableInfo], ) -> dict[str, Any]: + """ + AI 서버의 응답을 받아서 DB에 저장할 수 있는 모델 딕셔너리로 변환합니다. + """ now = datetime.now() annotation_id = generate_prefixed_uuid(DBSaveIdEnum.database_annotation.value) + # 원본 스키마 정보를 쉽게 조회할 수 있도록 룩업 테이블 생성 + schema_lookup: dict[str, UserDBTableInfo] = {table.name: table for table in full_schema_info} + db_anno = DatabaseAnnotationInDB( id=annotation_id, db_profile_id=db_profile_id, - database_name=db_profile.name, + database_name=db_profile.name or db_profile.username, description=ai_response.get("database_annotation"), created_at=now, updated_at=now, ) - table_annos, col_annos, constraint_annos, constraint_col_annos, index_annos, index_col_annos = ( + ( + all_table_annos, + all_col_annos, + all_constraint_annos, + all_constraint_col_annos, + all_index_annos, + all_index_col_annos, + ) = ( [], [], [], @@ -106,92 +132,184 @@ def _transform_ai_response_to_db_models( ) for tbl_data in ai_response.get("tables", []): - table_id = generate_prefixed_uuid(DBSaveIdEnum.table_annotation.value) - table_annos.append( - TableAnnotationInDB( - id=table_id, - database_annotation_id=annotation_id, - table_name=tbl_data["table_name"], - description=tbl_data.get("annotation"), + original_table = schema_lookup.get(tbl_data["table_name"]) + if not original_table: + continue + + ( + table_anno, + col_annos, + constraint_annos, + constraint_col_annos, + index_annos, + index_col_annos, + ) = self._create_annotations_for_table(tbl_data, original_table, annotation_id, now) + + all_table_annos.append(table_anno) + all_col_annos.extend(col_annos) + all_constraint_annos.extend(constraint_annos) + all_constraint_col_annos.extend(constraint_col_annos) + all_index_annos.extend(index_annos) + all_index_col_annos.extend(index_col_annos) + + return { + "db_annotation": db_anno, + "table_annotations": all_table_annos, + "column_annotations": all_col_annos, + "constraint_annotations": all_constraint_annos, + "constraint_column_annotations": all_constraint_col_annos, + "index_annotations": all_index_annos, + "index_column_annotations": all_index_col_annos, + } + + def _create_annotations_for_table( + self, + tbl_data: dict[str, Any], + original_table: UserDBTableInfo, + database_annotation_id: str, + now: datetime, + ) -> tuple: + """ + 단일 테이블에 대한 모든 하위 어노테이션(컬럼, 제약조건, 인덱스)을 생성합니다. + """ + table_id = generate_prefixed_uuid(DBSaveIdEnum.table_annotation.value) + table_anno = TableAnnotationInDB( + id=table_id, + database_annotation_id=database_annotation_id, + table_name=original_table.name, + description=tbl_data.get("annotation"), + created_at=now, + updated_at=now, + ) + + col_map = { + col.name: generate_prefixed_uuid(DBSaveIdEnum.column_annotation.value) for col in original_table.columns + } + + col_annos = self._process_columns(tbl_data, original_table, table_id, col_map, now) + constraint_annos, constraint_col_annos = self._process_constraints( + tbl_data, original_table, table_id, col_map, now + ) + index_annos, index_col_annos = self._process_indexes(tbl_data, original_table, table_id, col_map, now) + + return table_anno, col_annos, constraint_annos, constraint_col_annos, index_annos, index_col_annos + + def _process_columns( + self, tbl_data: dict, original_table: UserDBTableInfo, table_id: str, col_map: dict, now: datetime + ) -> list[ColumnAnnotationInDB]: + """ + 테이블의 컬럼 어노테이션 모델 리스트를 생성합니다. + """ + col_annos = [] + for col_data in tbl_data.get("columns", []): + original_column = next((c for c in original_table.columns if c.name == col_data["column_name"]), None) + if not original_column: + continue + col_annos.append( + ColumnAnnotationInDB( + id=col_map[original_column.name], + table_annotation_id=table_id, + column_name=original_column.name, + data_type=original_column.type, + is_nullable=1 if original_column.nullable else 0, + default_value=original_column.default, + description=col_data.get("annotation"), + ordinal_position=original_column.ordinal_position, created_at=now, updated_at=now, ) ) + return col_annos - col_map = { - col["column_name"]: generate_prefixed_uuid(DBSaveIdEnum.column_annotation.value) - for col in tbl_data.get("columns", []) - } - - for col_data in tbl_data.get("columns", []): - col_annos.append( - ColumnAnnotationInDB( - id=col_map[col_data["column_name"]], - table_annotation_id=table_id, - column_name=col_data["column_name"], - description=col_data.get("annotation"), - created_at=now, - updated_at=now, - ) + def _process_constraints( + self, tbl_data: dict, original_table: UserDBTableInfo, table_id: str, col_map: dict, now: datetime + ) -> tuple[list[TableConstraintInDB], list[ConstraintColumnInDB]]: + """ + 테이블의 제약조건 및 제약조건 컬럼 어노테이션 모델 리스트를 생성합니다. + """ + constraint_annos, constraint_col_annos = [], [] + for const_data in tbl_data.get("constraints", []): + original_constraint = next((c for c in original_table.constraints if c.name == const_data["name"]), None) + if not original_constraint: + continue + const_id = generate_prefixed_uuid(DBSaveIdEnum.table_constraint.value) + constraint_annos.append( + TableConstraintInDB( + id=const_id, + table_annotation_id=table_id, + name=original_constraint.name, + constraint_type=ConstraintTypeEnum(original_constraint.type), + description=const_data.get("annotation"), + ref_table=original_constraint.referenced_table, + expression=original_constraint.check_expression, + on_update_action=original_constraint.on_update, + on_delete_action=original_constraint.on_delete, + created_at=now, + updated_at=now, ) - - for const_data in tbl_data.get("constraints", []): - const_id = generate_prefixed_uuid(DBSaveIdEnum.table_constraint.value) - constraint_annos.append( - TableConstraintInDB( - id=const_id, - table_annotation_id=table_id, - name=const_data["name"], - constraint_type=ConstraintTypeEnum(const_data["type"]), + ) + for i, col_name in enumerate(original_constraint.columns): + if col_name not in col_map: + continue + constraint_col_annos.append( + ConstraintColumnInDB( + id=generate_prefixed_uuid(DBSaveIdEnum.constraint_column.value), + constraint_id=const_id, + column_annotation_id=col_map[col_name], + position=i + 1, + referenced_column_name=( + original_constraint.referenced_columns[i] + if original_constraint.referenced_columns + and i < len(original_constraint.referenced_columns) + else None + ), created_at=now, updated_at=now, ) ) - for col_name in const_data.get("columns", []): - constraint_col_annos.append( - ConstraintColumnInDB( - id=generate_prefixed_uuid(DBSaveIdEnum.constraint_column.value), - constraint_id=const_id, - column_annotation_id=col_map[col_name], - created_at=now, - updated_at=now, - ) - ) + return constraint_annos, constraint_col_annos - for idx_data in tbl_data.get("indexes", []): - idx_id = generate_prefixed_uuid(DBSaveIdEnum.index_annotation.value) - index_annos.append( - IndexAnnotationInDB( - id=idx_id, - table_annotation_id=table_id, - name=idx_data["name"], - is_unique=1 if idx_data.get("is_unique") else 0, + def _process_indexes( + self, tbl_data: dict, original_table: UserDBTableInfo, table_id: str, col_map: dict, now: datetime + ) -> tuple[list[IndexAnnotationInDB], list[IndexColumnInDB]]: + """ + 테이블의 인덱스 및 인덱스 컬럼 어노테이션 모델 리스트를 생성합니다. + """ + index_annos, index_col_annos = [], [] + for idx_data in tbl_data.get("indexes", []): + original_index = next((i for i in original_table.indexes if i.name == idx_data["name"]), None) + if not original_index: + continue + idx_id = generate_prefixed_uuid(DBSaveIdEnum.index_annotation.value) + index_annos.append( + IndexAnnotationInDB( + id=idx_id, + table_annotation_id=table_id, + name=original_index.name, + is_unique=1 if original_index.is_unique else 0, + created_at=now, + updated_at=now, + ) + ) + for i, col_name in enumerate(original_index.columns): + if col_name not in col_map: + continue + index_col_annos.append( + IndexColumnInDB( + id=generate_prefixed_uuid(DBSaveIdEnum.index_column.value), + index_id=idx_id, + column_annotation_id=col_map[col_name], + position=i + 1, created_at=now, updated_at=now, ) ) - for col_name in idx_data.get("columns", []): - index_col_annos.append( - IndexColumnInDB( - id=generate_prefixed_uuid(DBSaveIdEnum.index_column.value), - index_id=idx_id, - column_annotation_id=col_map[col_name], - created_at=now, - updated_at=now, - ) - ) - - return { - "db_annotation": db_anno, - "table_annotations": table_annos, - "column_annotations": col_annos, - "constraint_annotations": constraint_annos, - "constraint_column_annotations": constraint_col_annos, - "index_annotations": index_annos, - "index_column_annotations": index_col_annos, - } + return index_annos, index_col_annos def get_full_annotation(self, annotation_id: str) -> FullAnnotationResponse: + """ + ID를 기반으로 완전한 어노테이션 정보를 조회합니다. + """ try: annotation = self.repository.find_full_annotation_by_id(annotation_id) if not annotation: @@ -201,6 +319,9 @@ def get_full_annotation(self, annotation_id: str) -> FullAnnotationResponse: raise APIException(CommonCode.FAIL_FIND_ANNOTATION) from e def delete_annotation(self, annotation_id: str) -> AnnotationDeleteResponse: + """ + ID를 기반으로 어노테이션 및 관련 하위 데이터를 모두 삭제합니다. + """ try: is_deleted = self.repository.delete_annotation_by_id(annotation_id) if not is_deleted: diff --git a/app/services/user_db_service.py b/app/services/user_db_service.py index 3afe6fd..eff501a 100644 --- a/app/services/user_db_service.py +++ b/app/services/user_db_service.py @@ -192,10 +192,12 @@ def get_full_schema_info( try: constraints = repository.find_constraints( - driver_module, db_info.type, table_name, **connect_kwargs + driver_module, db_info.type, schema_name, table_name, **connect_kwargs ) - indexes = repository.find_indexes(driver_module, db_info.type, table_name, **connect_kwargs) - except sqlite3.Error as e: + indexes = repository.find_indexes( + driver_module, db_info.type, schema_name, table_name, **connect_kwargs + ) + except (sqlite3.Error, self._get_driver_module(db_info.type).Error) as e: # 레포지토리에서 발생한 DB 예외를 서비스에서 처리 raise APIException(CommonCode.FAIL_FIND_CONSTRAINTS_OR_INDEXES) from e @@ -268,7 +270,7 @@ def _get_schema_query(self, db_type: str) -> str | None: if db_type == "postgresql": return """ SELECT schema_name FROM information_schema.schemata - WHERE schema_name NOT IN ('pg_catalog', 'information_schema') + WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') """ elif db_type in ["mysql", "mariadb"]: return "SELECT schema_name FROM information_schema.schemata" @@ -288,7 +290,7 @@ def _get_table_query(self, db_type: str, for_all_schemas: bool = False) -> str | """ else: return """ - SELECT table_name, table_schema FROM information_schema.tables + SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = %s """ elif db_type in ["mysql", "mariadb"]: @@ -312,11 +314,38 @@ def _get_column_query(self, db_type: str) -> str | None: db_type = db_type.lower() if db_type == "postgresql": return """ - SELECT column_name, data_type, is_nullable, column_default, table_name, table_schema - FROM information_schema.columns - WHERE table_schema NOT IN ('pg_catalog', 'information_schema') - AND table_schema = %s - AND table_name = %s + SELECT + c.column_name, + c.data_type, + c.is_nullable, + c.column_default, + pgd.description AS comment, + ( + SELECT TRUE + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + WHERE tc.constraint_type = 'PRIMARY KEY' + AND tc.table_schema = c.table_schema + AND tc.table_name = c.table_name + AND kcu.column_name = c.column_name + ) AS is_pk, + c.character_maximum_length, + c.numeric_precision, + c.numeric_scale + FROM + information_schema.columns c + LEFT JOIN + pg_catalog.pg_stat_all_tables st + ON c.table_schema = st.schemaname AND c.table_name = st.relname + LEFT JOIN + pg_catalog.pg_description pgd + ON pgd.objoid = st.relid AND pgd.objsubid = c.ordinal_position + WHERE + c.table_schema = %s AND c.table_name = %s + ORDER BY + c.ordinal_position; """ elif db_type in ["mysql", "mariadb"]: return """