diff --git a/app/api/annotation_api.py b/app/api/annotation_api.py index aa5ce45..f974189 100644 --- a/app/api/annotation_api.py +++ b/app/api/annotation_api.py @@ -43,6 +43,22 @@ def get_annotation( return ResponseMessage.success(value=annotation, code=CommonCode.SUCCESS_FIND_ANNOTATION) +@router.get( + "/find/db/{db_profile_id}", + response_model=ResponseMessage[FullAnnotationResponse], + summary="DB 프로필 ID로 어노테이션 조회", +) +def get_annotation_by_db_profile_id( + db_profile_id: str, + service: AnnotationService = annotation_service_dependency, +) -> ResponseMessage[FullAnnotationResponse]: + """ + `db_profile_id`에 연결된 어노테이션의 전체 상세 정보를 조회합니다. + """ + annotation = service.get_annotation_by_db_profile_id(db_profile_id) + return ResponseMessage.success(value=annotation, code=CommonCode.SUCCESS_FIND_ANNOTATION) + + @router.delete( "/remove/{annotation_id}", response_model=ResponseMessage[AnnotationDeleteResponse], diff --git a/app/core/status.py b/app/core/status.py index 22d4488..e1574dd 100644 --- a/app/core/status.py +++ b/app/core/status.py @@ -59,6 +59,7 @@ class CommonCode(Enum): """ DRIVER, DB 클라이언트 에러 코드 - 41xx """ INVALID_DB_DRIVER = (status.HTTP_409_CONFLICT, "4100", "지원하지 않는 데이터베이스입니다.") NO_DB_DRIVER = (status.HTTP_400_BAD_REQUEST, "4101", "데이터베이스는 필수 값입니다.") + NO_DB_PROFILE_FOUND = (status.HTTP_404_NOT_FOUND, "4102", "해당 ID의 DB 프로필을 찾을 수 없습니다.") """ KEY 클라이언트 에러 코드 - 42xx """ INVALID_API_KEY_FORMAT = (status.HTTP_400_BAD_REQUEST, "4200", "API 키의 형식이 올바르지 않습니다.") @@ -86,6 +87,7 @@ class CommonCode(Enum): """ ANNOTATION 클라이언트 에러 코드 - 44xx """ INVALID_ANNOTATION_REQUEST = (status.HTTP_400_BAD_REQUEST, "4400", "어노테이션 요청 데이터가 유효하지 않습니다.") + NO_ANNOTATION_FOR_PROFILE = (status.HTTP_404_NOT_FOUND, "4401", "해당 DB 프로필에 연결된 어노테이션이 없습니다.") """ SQL 클라이언트 에러 코드 - 45xx """ @@ -116,6 +118,7 @@ class CommonCode(Enum): "5105", "디비 제약조건 또는 인덱스 정보 조회 중 에러가 발생했습니다.", ) + FAIL_FIND_SAMPLE_ROWS = (status.HTTP_500_INTERNAL_SERVER_ERROR, "5106", "샘플 데이터 조회 중 에러가 발생했습니다.") FAIL_SAVE_PROFILE = (status.HTTP_500_INTERNAL_SERVER_ERROR, "5130", "디비 정보 저장 중 에러가 발생했습니다.") FAIL_UPDATE_PROFILE = (status.HTTP_500_INTERNAL_SERVER_ERROR, "5150", "디비 정보 업데이트 중 에러가 발생했습니다.") FAIL_DELETE_PROFILE = (status.HTTP_500_INTERNAL_SERVER_ERROR, "5170", "디비 정보 삭제 중 에러가 발생했습니다.") diff --git a/app/db/init_db.py b/app/db/init_db.py index 2a83093..b90cf66 100644 --- a/app/db/init_db.py +++ b/app/db/init_db.py @@ -80,14 +80,18 @@ def initialize_database(): "username": "VARCHAR(128)", "password": "VARCHAR(128)", "view_name": "VARCHAR(64)", + "annotation_id": "VARCHAR(64)", "created_at": "DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP", "updated_at": "DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP", + "FOREIGN KEY (annotation_id)": "REFERENCES database_annotation(id) ON DELETE SET NULL", } create_sql = ( f"CREATE TABLE IF NOT EXISTS db_profile ({', '.join([f'{k} {v}' for k, v in db_profile_cols.items()])})" ) cursor.execute(create_sql) - _synchronize_table(cursor, "db_profile", db_profile_cols) + _synchronize_table( + cursor, "db_profile", {k: v for k, v in db_profile_cols.items() if not k.startswith("FOREIGN KEY")} + ) cursor.execute( """ @@ -303,6 +307,7 @@ def initialize_database(): "table_annotation_id": "VARCHAR(64) NOT NULL", "constraint_type": "VARCHAR(16) NOT NULL", "name": "VARCHAR(255)", + "description": "TEXT", "expression": "TEXT", "ref_table": "VARCHAR(255)", "on_update_action": "VARCHAR(16)", diff --git a/app/repository/annotation_repository.py b/app/repository/annotation_repository.py index c32ec61..8ad55c3 100644 --- a/app/repository/annotation_repository.py +++ b/app/repository/annotation_repository.py @@ -20,6 +20,11 @@ class AnnotationRepository: + """ + 어노테이션 데이터에 대한 데이터베이스 CRUD 작업을 처리합니다. + 모든 메서드는 내부적으로 `sqlite3`를 사용하여 로컬 DB와 상호작용합니다. + """ + def create_full_annotation( self, db_conn: sqlite3.Connection, @@ -96,6 +101,7 @@ def create_full_annotation( c.table_annotation_id, c.constraint_type, c.name, + c.description, c.expression, c.ref_table, c.on_update_action, @@ -107,8 +113,8 @@ def create_full_annotation( ] cursor.executemany( """ - INSERT INTO table_constraint (id, table_annotation_id, constraint_type, name, expression, ref_table, on_update_action, on_delete_action, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO table_constraint (id, table_annotation_id, constraint_type, name, description, expression, ref_table, on_update_action, on_delete_action, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, constraint_data, ) @@ -155,6 +161,20 @@ def create_full_annotation( index_column_data, ) + def update_db_profile_annotation_id( + self, db_conn: sqlite3.Connection, db_profile_id: str, annotation_id: str + ) -> None: + """ + 주어진 db_profile_id에 해당하는 레코드의 annotation_id를 업데이트합니다. + - 서비스 계층에서 트랜잭션을 관리하므로 connection을 인자로 받습니다. + - 실패 시 sqlite3.Error를 발생시킵니다. + """ + cursor = db_conn.cursor() + cursor.execute( + "UPDATE db_profile SET annotation_id = ? WHERE id = ?", + (annotation_id, db_profile_id), + ) + def find_full_annotation_by_id(self, annotation_id: str) -> FullAnnotationResponse | None: """ annotationId로 전체 어노테이션 상세 정보를 조회합니다. @@ -182,18 +202,24 @@ def find_full_annotation_by_id(self, annotation_id: str) -> FullAnnotationRespon # 컬럼 정보 cursor.execute( - "SELECT id, column_name, description FROM column_annotation WHERE table_annotation_id = ?", + "SELECT id, column_name, description, data_type, is_nullable, default_value FROM column_annotation WHERE table_annotation_id = ?", (table_id,), ) - columns = [ColumnAnnotationDetail.model_validate(dict(c)) for c in cursor.fetchall()] + columns = [] + for c in cursor.fetchall(): + c_dict = dict(c) + c_dict["is_nullable"] = ( + bool(c_dict["is_nullable"]) if c_dict.get("is_nullable") is not None else None + ) + columns.append(ColumnAnnotationDetail.model_validate(c_dict)) # 제약조건 정보 cursor.execute( """ - SELECT tc.name, tc.constraint_type, ca.column_name + SELECT tc.name, tc.constraint_type, tc.description, ca.column_name FROM table_constraint tc - JOIN constraint_column cc ON tc.id = cc.constraint_id - JOIN column_annotation ca ON cc.column_annotation_id = ca.id + LEFT JOIN constraint_column cc ON tc.id = cc.constraint_id + LEFT JOIN column_annotation ca ON cc.column_annotation_id = ca.id WHERE tc.table_annotation_id = ? """, (table_id,), @@ -201,10 +227,16 @@ def find_full_annotation_by_id(self, annotation_id: str) -> FullAnnotationRespon constraint_map = {} for row in cursor.fetchall(): if row["name"] not in constraint_map: - constraint_map[row["name"]] = {"type": row["constraint_type"], "columns": []} - constraint_map[row["name"]]["columns"].append(row["column_name"]) + constraint_map[row["name"]] = { + "type": row["constraint_type"], + "columns": [], + "description": row["description"], + } + if row["column_name"]: + constraint_map[row["name"]]["columns"].append(row["column_name"]) constraints = [ - ConstraintDetail(name=k, type=v["type"], columns=v["columns"]) for k, v in constraint_map.items() + ConstraintDetail(name=k, type=v["type"], columns=v["columns"], description=v["description"]) + for k, v in constraint_map.items() ] # 인덱스 정보 diff --git a/app/repository/user_db_repository.py b/app/repository/user_db_repository.py index b437572..d30d7c7 100644 --- a/app/repository/user_db_repository.py +++ b/app/repository/user_db_repository.py @@ -3,6 +3,7 @@ import oracledb +from app.core.enum.db_driver import DBTypesEnum from app.core.exceptions import APIException from app.core.status import CommonCode from app.core.utils import get_db_path @@ -141,8 +142,10 @@ def find_profile(self, sql: str, data: tuple) -> AllDBProfileInfo: row = cursor.fetchone() if not row: - raise APIException(CommonCode.NO_SEARCH_DATA) + raise APIException(CommonCode.NO_DB_PROFILE_FOUND) return AllDBProfileInfo(**dict(row)) + except APIException: + raise except sqlite3.Error as e: raise APIException(CommonCode.FAIL_FIND_PROFILE) from e except Exception as e: @@ -209,50 +212,12 @@ def find_columns( connection = self._connect(driver_module, **kwargs) cursor = connection.cursor() - if db_type == "sqlite": - # SQLite는 PRAGMA를 직접 실행 - pragma_sql = f"PRAGMA table_info('{table_name}')" - cursor.execute(pragma_sql) - columns_raw = cursor.fetchall() - columns = [ - ColumnInfo( - name=c[1], - type=c[2], - nullable=(c[3] == 0), # notnull == 0 means nullable - default=c[4], - comment=None, - is_pk=(c[5] == 1), - ) - for c in columns_raw - ] + if db_type == DBTypesEnum.sqlite.name: + columns = self._find_columns_for_sqlite(cursor, table_name) + elif db_type == DBTypesEnum.postgresql.name: + columns = self._find_columns_for_postgresql(cursor, schema_name, table_name) else: - if "%s" in column_query or "?" in column_query: - cursor.execute(column_query, (schema_name, table_name)) - elif ":owner" in column_query and ":table" in column_query: - owner_bind = schema_name.upper() if schema_name else schema_name - table_bind = table_name.upper() if table_name else table_name - try: - cursor.execute(column_query, {"owner": owner_bind, "table": table_bind}) - except Exception: - try: - pos_query = column_query.replace(":owner", ":1").replace(":table", ":2") - cursor.execute(pos_query, [owner_bind, table_bind]) - except Exception as e: - raise APIException(CommonCode.FAIL) from e - else: - cursor.execute(column_query) - - columns = [ - ColumnInfo( - name=c[0], - type=c[1], - nullable=(c[2] in ["YES", "Y", True]), - default=c[3], - comment=c[4] if len(c) > 4 else None, - is_pk=(c[5] in ["PRI", True] if len(c) > 5 else False), - ) - for c in cursor.fetchall() - ] + columns = self._find_columns_for_general(cursor, column_query, schema_name, table_name) return ColumnListResult(is_successful=True, code=CommonCode.SUCCESS_FIND_COLUMNS, columns=columns) except Exception: @@ -261,94 +226,363 @@ def find_columns( if connection: connection.close() + def _find_columns_for_sqlite(self, cursor: Any, table_name: str) -> list[ColumnInfo]: + pragma_sql = f"PRAGMA table_info('{table_name}')" + cursor.execute(pragma_sql) + columns_raw = cursor.fetchall() + # SQLite는 pragma에서 순서(cid)를 반환하지만, ordinal_position은 1부터 시작하는 표준이므로 +1 + return [ + ColumnInfo( + name=c[1], + type=c[2], + nullable=(c[3] == 0), + default=c[4], + comment=None, + is_pk=(c[5] == 1), + ordinal_position=c[0] + 1, + ) + for c in columns_raw + ] + + def _find_columns_for_postgresql(self, cursor: Any, schema_name: str, table_name: str) -> list[ColumnInfo]: + sql = """ + SELECT + c.column_name, + c.udt_name, + c.is_nullable, + c.column_default, + c.ordinal_position, + (SELECT pg_catalog.col_description(cls.oid, c.dtd_identifier::int) + FROM pg_catalog.pg_class cls + JOIN pg_catalog.pg_namespace n ON n.oid = cls.relnamespace + WHERE cls.relname = c.table_name AND n.nspname = c.table_schema) as comment, + CASE WHEN kcu.column_name IS NOT NULL THEN TRUE ELSE FALSE END as is_pk + FROM + information_schema.columns c + LEFT JOIN information_schema.key_column_usage kcu + ON c.table_schema = kcu.table_schema + AND c.table_name = kcu.table_name + AND c.column_name = kcu.column_name + AND kcu.constraint_name IN ( + SELECT constraint_name + FROM information_schema.table_constraints + WHERE table_schema = %s + AND table_name = %s + AND constraint_type = 'PRIMARY KEY' + ) + WHERE + c.table_schema = %s AND c.table_name = %s + ORDER BY + c.ordinal_position; + """ + cursor.execute(sql, (schema_name, table_name, schema_name, table_name)) + columns_raw = cursor.fetchall() + return [ + ColumnInfo( + name=c[0], + type=c[1], + nullable=(c[2] == "YES"), + default=c[3], + ordinal_position=c[4], + comment=c[5], + is_pk=c[6], + ) + for c in columns_raw + ] + + def _find_columns_for_general( + self, cursor: Any, column_query: str, schema_name: str, table_name: str + ) -> list[ColumnInfo]: + if "%s" in column_query or "?" in column_query: + cursor.execute(column_query, (schema_name, table_name)) + elif ":owner" in column_query and ":table" in column_query: + owner_bind = schema_name.upper() if schema_name else schema_name + table_bind = table_name.upper() if table_name else table_name + try: + cursor.execute(column_query, {"owner": owner_bind, "table": table_bind}) + except Exception: + try: + pos_query = column_query.replace(":owner", ":1").replace(":table", ":2") + cursor.execute(pos_query, [owner_bind, table_bind]) + except Exception as e: + raise APIException(CommonCode.FAIL) from e + else: + cursor.execute(column_query) + + columns = [] + for c in cursor.fetchall(): + data_type = c[1] + if c[6] is not None: + data_type = f"{data_type}({c[6]})" + elif c[7] is not None and c[8] is not None: + data_type = f"{data_type}({c[7]}, {c[8]})" + + columns.append( + ColumnInfo( + name=c[0], + type=data_type, + nullable=(c[2] in ["YES", "Y", True]), + default=c[3], + comment=c[4] if len(c) > 4 else None, + is_pk=(c[5] in ["PRI", True] if len(c) > 5 else False), + ) + ) + return columns + def find_constraints( - self, driver_module: Any, db_type: str, table_name: str, **kwargs: Any + self, driver_module: Any, db_type: str, schema_name: str, table_name: str, **kwargs: Any ) -> list[ConstraintInfo]: """ 테이블의 제약 조건 정보를 조회합니다. - - 현재는 SQLite만 지원합니다. + - 현재는 SQLite, PostgreSQL만 지원합니다. - 실패 시 DB 드라이버의 예외를 직접 발생시킵니다. """ connection = None try: connection = self._connect(driver_module, **kwargs) cursor = connection.cursor() - constraints = [] - - if db_type == "sqlite": - # Foreign Key 제약 조건 조회 - fk_list_sql = f"PRAGMA foreign_key_list('{table_name}')" - cursor.execute(fk_list_sql) - fks = cursor.fetchall() - - # Foreign Key 정보를 그룹화 - fk_groups = {} - for fk in fks: - fk_id = fk[0] - if fk_id not in fk_groups: - fk_groups[fk_id] = {"referenced_table": fk[2], "columns": [], "referenced_columns": []} - fk_groups[fk_id]["columns"].append(fk[3]) - fk_groups[fk_id]["referenced_columns"].append(fk[4]) - - for _, group in fk_groups.items(): - constraints.append( - ConstraintInfo( - name=f"fk_{table_name}_{'_'.join(group['columns'])}", - type="FOREIGN KEY", - columns=group["columns"], - referenced_table=group["referenced_table"], - referenced_columns=group["referenced_columns"], - ) - ) - - # 다른 DB 타입에 대한 제약 조건 조회 로직 추가 가능 - # elif db_type == "postgresql": ... - - return constraints + + if db_type == DBTypesEnum.sqlite.name: + return self._find_constraints_for_sqlite(cursor, table_name) + elif db_type == DBTypesEnum.postgresql.name: + return self._find_constraints_for_postgresql(cursor, schema_name, table_name) + # elif db_type == ...: + return [] finally: if connection: connection.close() - def find_indexes(self, driver_module: Any, db_type: str, table_name: str, **kwargs: Any) -> list[IndexInfo]: + def _find_constraints_for_sqlite(self, cursor: Any, table_name: str) -> list[ConstraintInfo]: + constraints = [] + fk_list_sql = f"PRAGMA foreign_key_list('{table_name}')" + cursor.execute(fk_list_sql) + fks = cursor.fetchall() + + # Foreign Key 정보를 그룹화 + fk_groups = {} + for fk in fks: + fk_id = fk[0] + if fk_id not in fk_groups: + fk_groups[fk_id] = {"referenced_table": fk[2], "columns": [], "referenced_columns": []} + fk_groups[fk_id]["columns"].append(fk[3]) + fk_groups[fk_id]["referenced_columns"].append(fk[4]) + + for _, group in fk_groups.items(): + constraints.append( + ConstraintInfo( + name=f"fk_{table_name}_{'_'.join(group['columns'])}", + type="FOREIGN KEY", + columns=group["columns"], + referenced_table=group["referenced_table"], + referenced_columns=group["referenced_columns"], + ) + ) + return constraints + + def _find_constraints_for_postgresql(self, cursor: Any, schema_name: str, table_name: str) -> list[ConstraintInfo]: + sql = """ + SELECT + tc.constraint_name, + tc.constraint_type, + kcu.column_name, + rc.update_rule, + rc.delete_rule, + ccu.table_name AS foreign_table_name, + ccu.column_name AS foreign_column_name, + chk.check_clause + FROM + information_schema.table_constraints tc + LEFT JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema AND tc.table_name = kcu.table_name + LEFT JOIN information_schema.referential_constraints rc + ON tc.constraint_name = rc.constraint_name AND tc.table_schema = rc.constraint_schema + LEFT JOIN information_schema.constraint_column_usage ccu + ON rc.unique_constraint_name = ccu.constraint_name AND rc.unique_constraint_schema = ccu.table_schema + LEFT JOIN information_schema.check_constraints chk + ON tc.constraint_name = chk.constraint_name AND tc.table_schema = chk.constraint_schema + WHERE + tc.table_schema = %s AND tc.table_name = %s; + """ + cursor.execute(sql, (schema_name, table_name)) + raw_constraints = cursor.fetchall() + + constraint_map = {} + for row in raw_constraints: + # Filter out 'NOT NULL' constraints which are handled by `is_nullable` in column info + const_type = row[1] + check_clause = row[7] + if const_type == "CHECK" and check_clause and "IS NOT NULL" in check_clause.upper(): + continue + + (name, _, column, on_update, on_delete, ref_table, ref_column, check_expr) = row + if name not in constraint_map: + constraint_map[name] = { + "type": const_type, + "columns": [], + "referenced_table": ref_table, + "referenced_columns": [], + "check_expression": check_expr, + "on_update": on_update, + "on_delete": on_delete, + } + if column and column not in constraint_map[name]["columns"]: + constraint_map[name]["columns"].append(column) + if ref_column and ref_column not in constraint_map[name]["referenced_columns"]: + constraint_map[name]["referenced_columns"].append(ref_column) + + return [ + ConstraintInfo( + name=name, + type=data["type"], + columns=data["columns"], + referenced_table=data["referenced_table"], + referenced_columns=data["referenced_columns"] if data["referenced_columns"] else None, + check_expression=data["check_expression"], + on_update=data["on_update"], + on_delete=data["on_delete"], + ) + for name, data in constraint_map.items() + ] + + def find_indexes( + self, driver_module: Any, db_type: str, schema_name: str, table_name: str, **kwargs: Any + ) -> list[IndexInfo]: """ 테이블의 인덱스 정보를 조회합니다. - - 현재는 SQLite만 지원합니다. - 실패 시 DB 드라이버의 예외를 직접 발생시킵니다. """ connection = None try: connection = self._connect(driver_module, **kwargs) cursor = connection.cursor() - indexes = [] - - if db_type == "sqlite": - index_list_sql = f"PRAGMA index_list('{table_name}')" - cursor.execute(index_list_sql) - raw_indexes = cursor.fetchall() - for idx in raw_indexes: - index_name = idx[1] - is_unique = idx[2] == 1 - - # "sqlite_autoindex_"로 시작하는 인덱스는 PK에 의해 자동 생성된 것이므로 제외 - if index_name.startswith("sqlite_autoindex_"): - continue - - index_info_sql = f"PRAGMA index_info('{index_name}')" - cursor.execute(index_info_sql) - index_columns = [row[2] for row in cursor.fetchall()] - - if index_columns: - indexes.append(IndexInfo(name=index_name, columns=index_columns, is_unique=is_unique)) + if db_type == DBTypesEnum.sqlite.name: + return self._find_indexes_for_sqlite(cursor, table_name) + elif db_type == DBTypesEnum.postgresql.name: + return self._find_indexes_for_postgresql(cursor, schema_name, table_name) + # elif db_type == ...: + return [] + finally: + if connection: + connection.close() - # 다른 DB 타입에 대한 인덱스 조회 로직 추가 가능 - # elif db_type == "postgresql": ... + def _find_indexes_for_sqlite(self, cursor: Any, table_name: str) -> list[IndexInfo]: + indexes = [] + index_list_sql = f"PRAGMA index_list('{table_name}')" + cursor.execute(index_list_sql) + raw_indexes = cursor.fetchall() + + for idx in raw_indexes: + index_name = idx[1] + is_unique = idx[2] == 1 + + # "sqlite_autoindex_"로 시작하는 인덱스는 PK에 의해 자동 생성된 것이므로 제외 + if index_name.startswith("sqlite_autoindex_"): + continue + + index_info_sql = f"PRAGMA index_info('{index_name}')" + cursor.execute(index_info_sql) + index_columns = [row[2] for row in cursor.fetchall()] + + if index_columns: + indexes.append(IndexInfo(name=index_name, columns=index_columns, is_unique=is_unique)) + return indexes + + def _find_indexes_for_postgresql(self, cursor: Any, schema_name: str, table_name: str) -> list[IndexInfo]: + sql = """ + SELECT + i.relname as index_name, + a.attname as column_name, + ix.indisunique as is_unique, + ix.indisprimary as is_primary + FROM + pg_class t, + pg_class i, + pg_index ix, + pg_attribute a, + pg_namespace n + WHERE + t.oid = ix.indrelid + and i.oid = ix.indexrelid + and a.attrelid = t.oid + and a.attnum = ANY(ix.indkey) + and t.relkind = 'r' + and n.oid = t.relnamespace + and n.nspname = %s + and t.relname = %s + ORDER BY + i.relname, a.attnum; + """ + cursor.execute(sql, (schema_name, table_name)) + raw_indexes = cursor.fetchall() + + index_map = {} + for row in raw_indexes: + index_name, column_name, is_unique, is_primary = row + if is_primary: # Exclude indexes created for PRIMARY KEY constraints + continue + if index_name not in index_map: + index_map[index_name] = {"columns": [], "is_unique": is_unique} + index_map[index_name]["columns"].append(column_name) + + return [ + IndexInfo(name=name, columns=data["columns"], is_unique=data["is_unique"]) + for name, data in index_map.items() + ] + + def find_sample_rows( + self, driver_module: Any, db_type: str, schema_name: str, table_names: list[str], **kwargs: Any + ) -> dict[str, list[dict[str, Any]]]: + """ + 주어진 테이블 목록에 대해 상위 3개의 샘플 행을 조회합니다. + - 실패 시 DB 드라이버의 예외를 직접 발생시킵니다. + """ + connection = None + try: + connection = self._connect(driver_module, **kwargs) + cursor = connection.cursor() - return indexes + if db_type == DBTypesEnum.sqlite.name: + return self._find_sample_rows_for_sqlite(cursor, table_names) + elif db_type == DBTypesEnum.postgresql.name: + return self._find_sample_rows_for_postgresql(cursor, schema_name, table_names) + # elif db_type == ...: + return {table_name: [] for table_name in table_names} finally: if connection: connection.close() + def _find_sample_rows_for_sqlite(self, cursor: Any, table_names: list[str]) -> dict[str, list[dict[str, Any]]]: + sample_rows_map = {} + for table_name in table_names: + try: + # 컬럼명 조회를 위해 PRAGMA 사용 + cursor.execute(f"PRAGMA table_info('{table_name}')") + columns = [row[1] for row in cursor.fetchall()] + + # 데이터 조회 + cursor.execute(f'SELECT * FROM "{table_name}" LIMIT 3') + rows = cursor.fetchall() + sample_rows_map[table_name] = [dict(zip(columns, row, strict=False)) for row in rows] + except Exception: + sample_rows_map[table_name] = [] # 오류 발생 시 빈 리스트 할당 + return sample_rows_map + + def _find_sample_rows_for_postgresql( + self, cursor: Any, schema_name: str, table_names: list[str] + ) -> dict[str, list[dict[str, Any]]]: + sample_rows_map = {} + for table_name in table_names: + try: + # PostgreSQL은 cursor.description을 통해 컬럼명을 바로 얻을 수 있음 + cursor.execute(f'SELECT * FROM "{schema_name}"."{table_name}" LIMIT 3') + columns = [desc[0] for desc in cursor.description] + rows = cursor.fetchall() + sample_rows_map[table_name] = [dict(zip(columns, row, strict=False)) for row in rows] + except Exception: + sample_rows_map[table_name] = [] + return sample_rows_map + # ───────────────────────────── # DB 연결 메서드 # ───────────────────────────── diff --git a/app/schemas/annotation/ai_model.py b/app/schemas/annotation/ai_model.py new file mode 100644 index 0000000..f902699 --- /dev/null +++ b/app/schemas/annotation/ai_model.py @@ -0,0 +1,64 @@ +from typing import Any + +from pydantic import BaseModel, Field + + +class AIColumnInfo(BaseModel): + """AI 요청을 위한 컬럼 정보 모델""" + + column_name: str = Field(..., description="컬럼 이름") + data_type: str = Field(..., description="데이터 타입") + is_pk: bool = Field(False, description="기본 키(Primary Key) 여부") + is_nullable: bool = Field(..., description="NULL 허용 여부") + default_value: Any | None = Field(None, description="기본값") + + +class AIConstraintInfo(BaseModel): + """AI 요청을 위한 제약 조건 정보 모델 (FK 제외)""" + + name: str | None = Field(None, description="제약 조건 이름") + type: str = Field(..., description="제약 조건 타입 (PRIMARY KEY, UNIQUE, CHECK)") + columns: list[str] = Field(..., description="제약 조건에 포함된 컬럼 목록") + check_expression: str | None = Field(None, description="CHECK 제약 조건 표현식") + + +class AIIndexInfo(BaseModel): + """AI 요청을 위한 인덱스 정보 모델""" + + name: str | None = Field(None, description="인덱스 이름") + columns: list[str] = Field(..., description="인덱스에 포함된 컬럼 목록") + is_unique: bool = Field(False, description="고유 인덱스 여부") + + +class AITableInfo(BaseModel): + """AI 요청을 위한 테이블 정보 모델""" + + table_name: str = Field(..., description="테이블 이름") + columns: list[AIColumnInfo] = Field(..., description="컬럼 목록") + constraints: list[AIConstraintInfo] = Field([], description="제약 조건 목록 (FK 제외)") + indexes: list[AIIndexInfo] = Field([], description="인덱스 목록") + sample_rows: list[dict[str, Any]] = Field([], description="테이블 샘플 데이터") + + +class AIRelationship(BaseModel): + """AI 요청을 위한 관계(FK) 정보 모델""" + + from_table: str = Field(..., description="관계를 시작하는 테이블") + from_columns: list[str] = Field(..., description="관계를 시작하는 컬럼") + to_table: str = Field(..., description="관계를 맺는 대상 테이블") + to_columns: list[str] = Field(..., description="관계를 맺는 대상 컬럼") + + +class AIDatabaseInfo(BaseModel): + """AI 요청을 위한 데이터베이스 정보 모델""" + + database_name: str = Field(..., description="데이터베이스 이름") + tables: list[AITableInfo] = Field(..., description="테이블 목록") + relationships: list[AIRelationship] = Field([], description="관계(FK) 목록") + + +class AIAnnotationRequest(BaseModel): + """AI 어노테이션 생성 요청 최상위 모델""" + + dbms_type: str = Field(..., description="DBMS 종류") + databases: list[AIDatabaseInfo] = Field(..., description="데이터베이스 목록") diff --git a/app/schemas/annotation/db_model.py b/app/schemas/annotation/db_model.py index 667765b..8a73aa9 100644 --- a/app/schemas/annotation/db_model.py +++ b/app/schemas/annotation/db_model.py @@ -39,6 +39,7 @@ class TableConstraintInDB(AnnotationBase): table_annotation_id: str constraint_type: ConstraintTypeEnum name: str | None = None + description: str | None = None expression: str | None = None ref_table: str | None = None on_update_action: str | None = None diff --git a/app/schemas/annotation/response_model.py b/app/schemas/annotation/response_model.py index 5602e0e..5ec4556 100644 --- a/app/schemas/annotation/response_model.py +++ b/app/schemas/annotation/response_model.py @@ -10,6 +10,9 @@ class ColumnAnnotationDetail(BaseModel): id: str column_name: str description: str | None = None + data_type: str | None = None + is_nullable: bool | None = None + default_value: str | None = None class ConstraintDetail(BaseModel): diff --git a/app/schemas/user_db/db_profile_model.py b/app/schemas/user_db/db_profile_model.py index 49598d8..32d666a 100644 --- a/app/schemas/user_db/db_profile_model.py +++ b/app/schemas/user_db/db_profile_model.py @@ -57,10 +57,12 @@ def _is_empty(value: Any | None) -> bool: class UpdateOrCreateDBProfile(DBProfileInfo): id: str | None = Field(None, description="DB Key 값") view_name: str | None = Field(None, description="DB 노출명") + annotation_id: str | None = Field(None, description="연결된 어노테이션 ID") class AllDBProfileInfo(DBProfileInfo): id: str | None = Field(..., description="DB Key 값") view_name: str | None = Field(None, description="DB 노출명") + annotation_id: str | None = Field(None, description="연결된 어노테이션 ID") created_at: datetime = Field(..., description="profile 저장일") updated_at: datetime = Field(..., description="profile 수정일") diff --git a/app/schemas/user_db/result_model.py b/app/schemas/user_db/result_model.py index 64d9190..d2a5b2a 100644 --- a/app/schemas/user_db/result_model.py +++ b/app/schemas/user_db/result_model.py @@ -30,6 +30,7 @@ class DBProfile(BaseModel): name: str | None username: str | None view_name: str | None + annotation_id: str | None = None created_at: datetime updated_at: datetime @@ -53,6 +54,7 @@ class ColumnInfo(BaseModel): default: Any | None = Field(None, description="기본값") comment: str | None = Field(None, description="코멘트") is_pk: bool = Field(False, description="기본 키(Primary Key) 여부") + ordinal_position: int | None = Field(None, description="컬럼 순서") class ConstraintInfo(BaseModel): @@ -64,6 +66,8 @@ class ConstraintInfo(BaseModel): # FOREIGN KEY 관련 필드 referenced_table: str | None = Field(None, description="참조하는 테이블 (FK)") referenced_columns: list[str] | None = Field(None, description="참조하는 테이블의 컬럼 (FK)") + on_update: str | None = Field(None, description="UPDATE 시 동작 (FK)") + on_delete: str | None = Field(None, description="DELETE 시 동작 (FK)") # CHECK 관련 필드 check_expression: str | None = Field(None, description="CHECK 제약 조건 표현식") diff --git a/app/services/annotation_service.py b/app/services/annotation_service.py index b217586..e10fe0b 100644 --- a/app/services/annotation_service.py +++ b/app/services/annotation_service.py @@ -10,6 +10,15 @@ from app.core.status import CommonCode from app.core.utils import generate_prefixed_uuid, get_db_path from app.repository.annotation_repository import AnnotationRepository, annotation_repository +from app.schemas.annotation.ai_model import ( + AIAnnotationRequest, + AIColumnInfo, + AIConstraintInfo, + AIDatabaseInfo, + AIIndexInfo, + AIRelationship, + AITableInfo, +) from app.schemas.annotation.db_model import ( ColumnAnnotationInDB, ConstraintColumnInDB, @@ -36,40 +45,59 @@ class AnnotationService: def __init__( self, repository: AnnotationRepository = annotation_repository, user_db_serv: UserDbService = user_db_service ): + """ + AnnotationService를 초기화합니다. + + Args: + repository (AnnotationRepository): 어노테이션 레포지토리 의존성 주입. + user_db_serv (UserDbService): 사용자 DB 서비스 의존성 주입. + """ self.repository = repository self.user_db_service = user_db_serv async def create_annotation(self, request: AnnotationCreateRequest) -> FullAnnotationResponse: """ 어노테이션 생성을 위한 전체 프로세스를 관장합니다. - 1. DB 프로필 및 전체 스키마 정보 조회 - 2. TODO: AI 서버에 요청 (현재는 Mock 데이터 사용) - 3. 트랜잭션 내에서 전체 어노테이션 정보 저장 + 1. DB 프로필, 전체 스키마 정보, 샘플 데이터 조회 + 2. AI 서버에 요청할 데이터 모델 생성 + 3. TODO: AI 서버에 요청 (현재는 Mock 데이터 사용) + 4. 트랜잭션 내에서 전체 어노테이션 정보 저장 및 DB 프로필 업데이트 """ try: request.validate() except ValueError as e: raise APIException(CommonCode.INVALID_ANNOTATION_REQUEST, detail=str(e)) from e - # 1. DB 프로필 및 전체 스키마 정보 조회 + # 1. DB 프로필, 전체 스키마 정보, 샘플 데이터 조회 db_profile = self.user_db_service.find_profile(request.db_profile_id) full_schema_info = self.user_db_service.get_full_schema_info(db_profile) + sample_rows = self.user_db_service.get_sample_rows(db_profile, full_schema_info) + + # 2. AI 서버에 요청할 데이터 모델 생성 + ai_request_body = self._prepare_ai_request_body(db_profile, full_schema_info, sample_rows) + print(ai_request_body.model_dump_json(indent=2)) - # 2. AI 서버에 요청 (현재는 Mock 데이터 사용) - ai_response = await self._request_annotation_to_ai_server(full_schema_info) + # 3. AI 서버에 요청 (현재는 Mock 데이터 사용) + ai_response = await self._request_annotation_to_ai_server(ai_request_body) - # 3. 트랜잭션 내에서 전체 어노테이션 정보 저장 + # 4. 트랜잭션 내에서 전체 어노테이션 정보 저장 및 DB 프로필 업데이트 db_path = get_db_path() conn = None try: conn = sqlite3.connect(str(db_path), timeout=10) conn.execute("BEGIN") - db_models = self._transform_ai_response_to_db_models(ai_response, db_profile, request.db_profile_id) + db_models = self._transform_ai_response_to_db_models( + ai_response, db_profile, request.db_profile_id, full_schema_info + ) self.repository.create_full_annotation(db_conn=conn, **db_models) - conn.commit() annotation_id = db_models["db_annotation"].id + self.repository.update_db_profile_annotation_id( + db_conn=conn, db_profile_id=request.db_profile_id, annotation_id=annotation_id + ) + + conn.commit() except sqlite3.Error as e: if conn: @@ -81,22 +109,111 @@ async def create_annotation(self, request: AnnotationCreateRequest) -> FullAnnot return self.get_full_annotation(annotation_id) + def get_annotation_by_db_profile_id(self, db_profile_id: str) -> FullAnnotationResponse: + """ + db_profile_id를 기반으로 완전한 어노테이션 정보를 조회합니다. + """ + db_profile = self.user_db_service.find_profile(db_profile_id) + if not db_profile.annotation_id: + raise APIException(CommonCode.NO_ANNOTATION_FOR_PROFILE) + + return self.get_full_annotation(db_profile.annotation_id) + + def _prepare_ai_request_body( + self, + db_profile: AllDBProfileInfo, + full_schema_info: list[UserDBTableInfo], + sample_rows: dict[str, list[dict[str, Any]]], + ) -> AIAnnotationRequest: + """ + AI 서버에 보낼 요청 본문을 Pydantic 모델로 생성합니다. + """ + ai_tables = [] + ai_relationships = [] + + for table_info in full_schema_info: + # FK 제약조건을 분리하여 relationships 목록 생성 + non_fk_constraints = [] + for const in table_info.constraints: + if const.type == "FOREIGN KEY" and const.referenced_table and const.referenced_columns: + ai_relationships.append( + AIRelationship( + from_table=table_info.name, + from_columns=const.columns, + to_table=const.referenced_table, + to_columns=const.referenced_columns, + ) + ) + else: + non_fk_constraints.append( + AIConstraintInfo( + name=const.name, + type=const.type, + columns=const.columns, + check_expression=const.check_expression, + ) + ) + + ai_table = AITableInfo( + table_name=table_info.name, + columns=[ + AIColumnInfo( + column_name=col.name, + data_type=col.type, + is_pk=col.is_pk, + is_nullable=col.nullable, + default_value=col.default, + ) + for col in table_info.columns + ], + constraints=non_fk_constraints, + indexes=[ + AIIndexInfo(name=idx.name, columns=idx.columns, is_unique=idx.is_unique) + for idx in table_info.indexes + ], + sample_rows=sample_rows.get(table_info.name, []), + ) + ai_tables.append(ai_table) + + ai_database = AIDatabaseInfo( + database_name=db_profile.name or db_profile.username, tables=ai_tables, relationships=ai_relationships + ) + + return AIAnnotationRequest(dbms_type=db_profile.type, databases=[ai_database]) + def _transform_ai_response_to_db_models( - self, ai_response: dict[str, Any], db_profile: AllDBProfileInfo, db_profile_id: str + self, + ai_response: dict[str, Any], + db_profile: AllDBProfileInfo, + db_profile_id: str, + full_schema_info: list[UserDBTableInfo], ) -> dict[str, Any]: + """ + AI 서버의 응답을 받아서 DB에 저장할 수 있는 모델 딕셔너리로 변환합니다. + """ now = datetime.now() annotation_id = generate_prefixed_uuid(DBSaveIdEnum.database_annotation.value) + # 원본 스키마 정보를 쉽게 조회할 수 있도록 룩업 테이블 생성 + schema_lookup: dict[str, UserDBTableInfo] = {table.name: table for table in full_schema_info} + db_anno = DatabaseAnnotationInDB( id=annotation_id, db_profile_id=db_profile_id, - database_name=db_profile.name, + database_name=db_profile.name or db_profile.username, description=ai_response.get("database_annotation"), created_at=now, updated_at=now, ) - table_annos, col_annos, constraint_annos, constraint_col_annos, index_annos, index_col_annos = ( + ( + all_table_annos, + all_col_annos, + all_constraint_annos, + all_constraint_col_annos, + all_index_annos, + all_index_col_annos, + ) = ( [], [], [], @@ -106,92 +223,184 @@ def _transform_ai_response_to_db_models( ) for tbl_data in ai_response.get("tables", []): - table_id = generate_prefixed_uuid(DBSaveIdEnum.table_annotation.value) - table_annos.append( - TableAnnotationInDB( - id=table_id, - database_annotation_id=annotation_id, - table_name=tbl_data["table_name"], - description=tbl_data.get("annotation"), + original_table = schema_lookup.get(tbl_data["table_name"]) + if not original_table: + continue + + ( + table_anno, + col_annos, + constraint_annos, + constraint_col_annos, + index_annos, + index_col_annos, + ) = self._create_annotations_for_table(tbl_data, original_table, annotation_id, now) + + all_table_annos.append(table_anno) + all_col_annos.extend(col_annos) + all_constraint_annos.extend(constraint_annos) + all_constraint_col_annos.extend(constraint_col_annos) + all_index_annos.extend(index_annos) + all_index_col_annos.extend(index_col_annos) + + return { + "db_annotation": db_anno, + "table_annotations": all_table_annos, + "column_annotations": all_col_annos, + "constraint_annotations": all_constraint_annos, + "constraint_column_annotations": all_constraint_col_annos, + "index_annotations": all_index_annos, + "index_column_annotations": all_index_col_annos, + } + + def _create_annotations_for_table( + self, + tbl_data: dict[str, Any], + original_table: UserDBTableInfo, + database_annotation_id: str, + now: datetime, + ) -> tuple: + """ + 단일 테이블에 대한 모든 하위 어노테이션(컬럼, 제약조건, 인덱스)을 생성합니다. + """ + table_id = generate_prefixed_uuid(DBSaveIdEnum.table_annotation.value) + table_anno = TableAnnotationInDB( + id=table_id, + database_annotation_id=database_annotation_id, + table_name=original_table.name, + description=tbl_data.get("annotation"), + created_at=now, + updated_at=now, + ) + + col_map = { + col.name: generate_prefixed_uuid(DBSaveIdEnum.column_annotation.value) for col in original_table.columns + } + + col_annos = self._process_columns(tbl_data, original_table, table_id, col_map, now) + constraint_annos, constraint_col_annos = self._process_constraints( + tbl_data, original_table, table_id, col_map, now + ) + index_annos, index_col_annos = self._process_indexes(tbl_data, original_table, table_id, col_map, now) + + return table_anno, col_annos, constraint_annos, constraint_col_annos, index_annos, index_col_annos + + def _process_columns( + self, tbl_data: dict, original_table: UserDBTableInfo, table_id: str, col_map: dict, now: datetime + ) -> list[ColumnAnnotationInDB]: + """ + 테이블의 컬럼 어노테이션 모델 리스트를 생성합니다. + """ + col_annos = [] + for col_data in tbl_data.get("columns", []): + original_column = next((c for c in original_table.columns if c.name == col_data["column_name"]), None) + if not original_column: + continue + col_annos.append( + ColumnAnnotationInDB( + id=col_map[original_column.name], + table_annotation_id=table_id, + column_name=original_column.name, + data_type=original_column.type, + is_nullable=1 if original_column.nullable else 0, + default_value=original_column.default, + description=col_data.get("annotation"), + ordinal_position=original_column.ordinal_position, created_at=now, updated_at=now, ) ) + return col_annos - col_map = { - col["column_name"]: generate_prefixed_uuid(DBSaveIdEnum.column_annotation.value) - for col in tbl_data.get("columns", []) - } - - for col_data in tbl_data.get("columns", []): - col_annos.append( - ColumnAnnotationInDB( - id=col_map[col_data["column_name"]], - table_annotation_id=table_id, - column_name=col_data["column_name"], - description=col_data.get("annotation"), - created_at=now, - updated_at=now, - ) + def _process_constraints( + self, tbl_data: dict, original_table: UserDBTableInfo, table_id: str, col_map: dict, now: datetime + ) -> tuple[list[TableConstraintInDB], list[ConstraintColumnInDB]]: + """ + 테이블의 제약조건 및 제약조건 컬럼 어노테이션 모델 리스트를 생성합니다. + """ + constraint_annos, constraint_col_annos = [], [] + for const_data in tbl_data.get("constraints", []): + original_constraint = next((c for c in original_table.constraints if c.name == const_data["name"]), None) + if not original_constraint: + continue + const_id = generate_prefixed_uuid(DBSaveIdEnum.table_constraint.value) + constraint_annos.append( + TableConstraintInDB( + id=const_id, + table_annotation_id=table_id, + name=original_constraint.name, + constraint_type=ConstraintTypeEnum(original_constraint.type), + description=const_data.get("annotation"), + ref_table=original_constraint.referenced_table, + expression=original_constraint.check_expression, + on_update_action=original_constraint.on_update, + on_delete_action=original_constraint.on_delete, + created_at=now, + updated_at=now, ) - - for const_data in tbl_data.get("constraints", []): - const_id = generate_prefixed_uuid(DBSaveIdEnum.table_constraint.value) - constraint_annos.append( - TableConstraintInDB( - id=const_id, - table_annotation_id=table_id, - name=const_data["name"], - constraint_type=ConstraintTypeEnum(const_data["type"]), + ) + for i, col_name in enumerate(original_constraint.columns): + if col_name not in col_map: + continue + constraint_col_annos.append( + ConstraintColumnInDB( + id=generate_prefixed_uuid(DBSaveIdEnum.constraint_column.value), + constraint_id=const_id, + column_annotation_id=col_map[col_name], + position=i + 1, + referenced_column_name=( + original_constraint.referenced_columns[i] + if original_constraint.referenced_columns + and i < len(original_constraint.referenced_columns) + else None + ), created_at=now, updated_at=now, ) ) - for col_name in const_data.get("columns", []): - constraint_col_annos.append( - ConstraintColumnInDB( - id=generate_prefixed_uuid(DBSaveIdEnum.constraint_column.value), - constraint_id=const_id, - column_annotation_id=col_map[col_name], - created_at=now, - updated_at=now, - ) - ) + return constraint_annos, constraint_col_annos - for idx_data in tbl_data.get("indexes", []): - idx_id = generate_prefixed_uuid(DBSaveIdEnum.index_annotation.value) - index_annos.append( - IndexAnnotationInDB( - id=idx_id, - table_annotation_id=table_id, - name=idx_data["name"], - is_unique=1 if idx_data.get("is_unique") else 0, + def _process_indexes( + self, tbl_data: dict, original_table: UserDBTableInfo, table_id: str, col_map: dict, now: datetime + ) -> tuple[list[IndexAnnotationInDB], list[IndexColumnInDB]]: + """ + 테이블의 인덱스 및 인덱스 컬럼 어노테이션 모델 리스트를 생성합니다. + """ + index_annos, index_col_annos = [], [] + for idx_data in tbl_data.get("indexes", []): + original_index = next((i for i in original_table.indexes if i.name == idx_data["name"]), None) + if not original_index: + continue + idx_id = generate_prefixed_uuid(DBSaveIdEnum.index_annotation.value) + index_annos.append( + IndexAnnotationInDB( + id=idx_id, + table_annotation_id=table_id, + name=original_index.name, + is_unique=1 if original_index.is_unique else 0, + created_at=now, + updated_at=now, + ) + ) + for i, col_name in enumerate(original_index.columns): + if col_name not in col_map: + continue + index_col_annos.append( + IndexColumnInDB( + id=generate_prefixed_uuid(DBSaveIdEnum.index_column.value), + index_id=idx_id, + column_annotation_id=col_map[col_name], + position=i + 1, created_at=now, updated_at=now, ) ) - for col_name in idx_data.get("columns", []): - index_col_annos.append( - IndexColumnInDB( - id=generate_prefixed_uuid(DBSaveIdEnum.index_column.value), - index_id=idx_id, - column_annotation_id=col_map[col_name], - created_at=now, - updated_at=now, - ) - ) - - return { - "db_annotation": db_anno, - "table_annotations": table_annos, - "column_annotations": col_annos, - "constraint_annotations": constraint_annos, - "constraint_column_annotations": constraint_col_annos, - "index_annotations": index_annos, - "index_column_annotations": index_col_annos, - } + return index_annos, index_col_annos def get_full_annotation(self, annotation_id: str) -> FullAnnotationResponse: + """ + ID를 기반으로 완전한 어노테이션 정보를 조회합니다. + """ try: annotation = self.repository.find_full_annotation_by_id(annotation_id) if not annotation: @@ -201,6 +410,9 @@ def get_full_annotation(self, annotation_id: str) -> FullAnnotationResponse: raise APIException(CommonCode.FAIL_FIND_ANNOTATION) from e def delete_annotation(self, annotation_id: str) -> AnnotationDeleteResponse: + """ + ID를 기반으로 어노테이션 및 관련 하위 데이터를 모두 삭제합니다. + """ try: is_deleted = self.repository.delete_annotation_by_id(annotation_id) if not is_deleted: @@ -209,13 +421,13 @@ def delete_annotation(self, annotation_id: str) -> AnnotationDeleteResponse: except sqlite3.Error as e: raise APIException(CommonCode.FAIL_DELETE_ANNOTATION) from e - async def _request_annotation_to_ai_server(self, schema_info: list[UserDBTableInfo]) -> dict: + async def _request_annotation_to_ai_server(self, ai_request: AIAnnotationRequest) -> dict: """AI 서버에 스키마 정보를 보내고 어노테이션을 받아옵니다.""" # 우선은 목업 데이터 활용 - return self._get_mock_ai_response(schema_info) + return self._get_mock_ai_response(ai_request) # Real implementation below - # request_body = {"database_schema": {"tables": [table.model_dump() for table in schema_info]}} + # request_body = ai_request.model_dump() # async with httpx.AsyncClient() as client: # try: # response = await client.post(AI_SERVER_URL, json=request_body, timeout=60.0) @@ -226,19 +438,21 @@ async def _request_annotation_to_ai_server(self, schema_info: list[UserDBTableIn # except httpx.RequestError as e: # raise APIException(CommonCode.FAIL_AI_SERVER_CONNECTION, detail=f"AI server connection failed: {e}") from e - def _get_mock_ai_response(self, schema_info: list[UserDBTableInfo]) -> dict: + def _get_mock_ai_response(self, ai_request: AIAnnotationRequest) -> dict: """테스트를 위한 Mock AI 서버 응답 생성""" + # 요청 데이터를 기반으로 동적으로 Mock 응답을 생성하도록 수정 + db_info = ai_request.databases[0] mock_response = { - "database_annotation": "Mock: 데이터베이스 전체에 대한 설명입니다.", + "database_annotation": f"Mock: '{db_info.database_name}' 데이터베이스 전체에 대한 설명입니다.", "tables": [], "relationships": [], } - for table in schema_info: + for table in db_info.tables: mock_table = { - "table_name": table.name, - "annotation": f"Mock: '{table.name}' 테이블에 대한 설명입니다.", + "table_name": table.table_name, + "annotation": f"Mock: '{table.table_name}' 테이블에 대한 설명입니다.", "columns": [ - {"column_name": col.name, "annotation": f"Mock: '{col.name}' 컬럼에 대한 설명입니다."} + {"column_name": col.column_name, "annotation": f"Mock: '{col.column_name}' 컬럼에 대한 설명입니다."} for col in table.columns ], "constraints": [ @@ -261,6 +475,17 @@ def _get_mock_ai_response(self, schema_info: list[UserDBTableInfo]) -> dict: ], } mock_response["tables"].append(mock_table) + + for rel in db_info.relationships: + mock_response["relationships"].append( + { + "from_table": rel.from_table, + "from_columns": rel.from_columns, + "to_table": rel.to_table, + "to_columns": rel.to_columns, + "annotation": f"Mock: '{rel.from_table}'과 '{rel.to_table}'의 관계 설명.", + } + ) return mock_response diff --git a/app/services/user_db_service.py b/app/services/user_db_service.py index 3afe6fd..782f396 100644 --- a/app/services/user_db_service.py +++ b/app/services/user_db_service.py @@ -34,7 +34,12 @@ def connection_test(self, db_info: DBProfileInfo, repository: UserDbRepository = try: driver_module = self._get_driver_module(db_info.type) connect_kwargs = self._prepare_connection_args(db_info) - return repository.connection_test(driver_module, **connect_kwargs) + result = repository.connection_test(driver_module, **connect_kwargs) + if not result.is_successful: + raise APIException(result.code) + return result + except APIException: + raise except Exception as e: raise APIException(CommonCode.FAIL) from e @@ -46,11 +51,15 @@ def create_profile( """ create_db_info.id = generate_prefixed_uuid(DBSaveIdEnum.user_db.value) try: - # [수정] 쿼리와 데이터를 서비스에서 생성하여 레포지토리로 전달합니다. sql, data = self._get_create_query_and_data(create_db_info) - return repository.create_profile(sql, data, create_db_info) + result = repository.create_profile(sql, data, create_db_info) + if not result.is_successful: + raise APIException(result.code) + return result + except APIException: + raise except Exception as e: - raise APIException(CommonCode.FAIL) from e + raise APIException(CommonCode.FAIL_SAVE_PROFILE) from e def update_profile( self, update_db_info: UpdateOrCreateDBProfile, repository: UserDbRepository = user_db_repository @@ -59,33 +68,45 @@ def update_profile( DB 연결 정보를 업데이트 후 결과를 반환합니다. """ try: - # [수정] 쿼리와 데이터를 서비스에서 생성하여 레포지토리로 전달합니다. sql, data = self._get_update_query_and_data(update_db_info) - return repository.update_profile(sql, data, update_db_info) + result = repository.update_profile(sql, data, update_db_info) + if not result.is_successful: + raise APIException(result.code) + return result + except APIException: + raise except Exception as e: - raise APIException(CommonCode.FAIL) from e + raise APIException(CommonCode.FAIL_UPDATE_PROFILE) from e def delete_profile(self, profile_id: str, repository: UserDbRepository = user_db_repository) -> ChangeProfileResult: """ DB 연결 정보를 삭제 후 결과를 반환합니다. """ try: - # [수정] 쿼리와 데이터를 서비스에서 생성하여 레포지토리로 전달합니다. sql, data = self._get_delete_query_and_data(profile_id) - return repository.delete_profile(sql, data, profile_id) + result = repository.delete_profile(sql, data, profile_id) + if not result.is_successful: + raise APIException(result.code) + return result + except APIException: + raise except Exception as e: - raise APIException(CommonCode.FAIL) from e + raise APIException(CommonCode.FAIL_DELETE_PROFILE) from e def find_all_profile(self, repository: UserDbRepository = user_db_repository) -> AllDBProfileResult: """ 모든 DB 연결 정보를 반환합니다. """ try: - # [수정] 쿼리를 서비스에서 생성하여 레포지토리로 전달합니다. sql = self._get_find_all_query() - return repository.find_all_profile(sql) + result = repository.find_all_profile(sql) + if not result.is_successful: + raise APIException(result.code) + return result + except APIException: + raise except Exception as e: - raise APIException(CommonCode.FAIL) from e + raise APIException(CommonCode.FAIL_FIND_PROFILE) from e def find_profile(self, profile_id, repository: UserDbRepository = user_db_repository) -> AllDBProfileInfo: """ @@ -95,8 +116,10 @@ def find_profile(self, profile_id, repository: UserDbRepository = user_db_reposi # [수정] 쿼리와 데이터를 서비스에서 생성하여 레포지토리로 전달합니다. sql, data = self._get_find_one_query_and_data(profile_id) return repository.find_profile(sql, data) + except APIException: + raise except Exception as e: - raise APIException(CommonCode.FAIL) from e + raise APIException(CommonCode.FAIL_FIND_PROFILE) from e def find_schemas( self, db_info: AllDBProfileInfo, repository: UserDbRepository = user_db_repository @@ -152,7 +175,7 @@ def find_columns( def get_full_schema_info( self, db_info: AllDBProfileInfo, repository: UserDbRepository = user_db_repository - ) -> SchemaInfoResult: + ) -> list[TableInfo]: """ DB 프로필 정보를 받아 해당 데이터베이스의 전체 스키마 정보 (테이블, 컬럼, 제약조건, 인덱스)를 조회하여 반환합니다. @@ -192,10 +215,12 @@ def get_full_schema_info( try: constraints = repository.find_constraints( - driver_module, db_info.type, table_name, **connect_kwargs + driver_module, db_info.type, schema_name, table_name, **connect_kwargs + ) + indexes = repository.find_indexes( + driver_module, db_info.type, schema_name, table_name, **connect_kwargs ) - indexes = repository.find_indexes(driver_module, db_info.type, table_name, **connect_kwargs) - except sqlite3.Error as e: + except (sqlite3.Error, self._get_driver_module(db_info.type).Error) as e: # 레포지토리에서 발생한 DB 예외를 서비스에서 처리 raise APIException(CommonCode.FAIL_FIND_CONSTRAINTS_OR_INDEXES) from e @@ -217,6 +242,24 @@ def get_full_schema_info( # 그 외 모든 예외는 일반 실패로 처리 raise APIException(CommonCode.FAIL) from e + def get_sample_rows( + self, db_info: AllDBProfileInfo, table_infos: list[TableInfo], repository: UserDbRepository = user_db_repository + ) -> dict[str, list[dict[str, Any]]]: + """ + 테이블 정보 목록을 받아 각 테이블의 샘플 데이터를 조회하여 반환합니다. + """ + try: + driver_module = self._get_driver_module(db_info.type) + connect_kwargs = self._prepare_connection_args(db_info) + + # SQLite는 스키마 이름이 필요 없음 + schema_name = db_info.name if db_info.type != "sqlite" else "" + table_names = [table.name for table in table_infos] + + return repository.find_sample_rows(driver_module, db_info.type, schema_name, table_names, **connect_kwargs) + except Exception as e: + raise APIException(CommonCode.FAIL_FIND_SAMPLE_ROWS) from e + def _get_driver_module(self, db_type: str): """ DB 타입에 따라 동적으로 드라이버 모듈을 로드합니다. @@ -268,7 +311,7 @@ def _get_schema_query(self, db_type: str) -> str | None: if db_type == "postgresql": return """ SELECT schema_name FROM information_schema.schemata - WHERE schema_name NOT IN ('pg_catalog', 'information_schema') + WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') """ elif db_type in ["mysql", "mariadb"]: return "SELECT schema_name FROM information_schema.schemata" @@ -288,7 +331,7 @@ def _get_table_query(self, db_type: str, for_all_schemas: bool = False) -> str | """ else: return """ - SELECT table_name, table_schema FROM information_schema.tables + SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = %s """ elif db_type in ["mysql", "mariadb"]: @@ -312,11 +355,38 @@ def _get_column_query(self, db_type: str) -> str | None: db_type = db_type.lower() if db_type == "postgresql": return """ - SELECT column_name, data_type, is_nullable, column_default, table_name, table_schema - FROM information_schema.columns - WHERE table_schema NOT IN ('pg_catalog', 'information_schema') - AND table_schema = %s - AND table_name = %s + SELECT + c.column_name, + c.data_type, + c.is_nullable, + c.column_default, + pgd.description AS comment, + ( + SELECT TRUE + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + WHERE tc.constraint_type = 'PRIMARY KEY' + AND tc.table_schema = c.table_schema + AND tc.table_name = c.table_name + AND kcu.column_name = c.column_name + ) AS is_pk, + c.character_maximum_length, + c.numeric_precision, + c.numeric_scale + FROM + information_schema.columns c + LEFT JOIN + pg_catalog.pg_stat_all_tables st + ON c.table_schema = st.schemaname AND c.table_name = st.relname + LEFT JOIN + pg_catalog.pg_description pgd + ON pgd.objoid = st.relid AND pgd.objsubid = c.ordinal_position + WHERE + c.table_schema = %s AND c.table_name = %s + ORDER BY + c.ordinal_position; """ elif db_type in ["mysql", "mariadb"]: return """