diff --git a/app/adapter/request_batch.py b/app/adapter/request_batch.py deleted file mode 100644 index a826013..0000000 --- a/app/adapter/request_batch.py +++ /dev/null @@ -1,83 +0,0 @@ -import asyncio -from typing import List, Optional - -from openai import APITimeoutError # AsyncOpenAI 타입 힌트용 - -from app.client.oepn_ai import get_gpt_client # AsyncOpenAI를 반환한다고 가정 -from app.util.logger import logger - - -async def request_responses_output_text(gpt_request: dict) -> str: - """Responses API로 단건 요청을 비동기로 전송하고 텍스트만 추출한다.""" - resp = await get_gpt_client().responses.create(**gpt_request) - - text = getattr(resp, "output_text", None) - if isinstance(text, str) and text.strip(): - return text - - output = getattr(resp, "output", None) - if isinstance(output, list): - for item in output: - if not isinstance(item, dict): - continue - if item.get("type") != "message": - continue - content = item.get("content") - if not isinstance(content, list): - continue - for c in content: - if not isinstance(c, dict): - continue - for key in ("text", "output_text", "value"): - v = c.get(key) - if isinstance(v, str) and v.strip(): - return v - - if not isinstance(text, str) or not text.strip(): - logger.warning("Responses API 응답에서 텍스트를 추출하지 못했습니다") - return "" - return text - - -async def request_responses_batch( - requests: List[dict], timeout: float -) -> List[Optional[str]]: - # 2. 개별 요청을 처리하는 내부 함수 - async def _one(req: dict) -> Optional[str]: - try: - text = await request_responses_output_text(req) - return text if text else None - except APITimeoutError: - logger.error("OpenAI API Timeout") - return None - except Exception: - logger.exception("Batch request 실패") - return None - - # 3. 태스크 생성 (여기서는 쓰레드가 아닌 가벼운 Coroutine이 생성됨) - tasks = [asyncio.create_task(_one(r)) for r in requests] - done, pending = await asyncio.wait(tasks, timeout=timeout) - - if pending: - logger.error( - f"Batch processing timed out after {timeout} seconds. {len(pending)} tasks incomplete." - ) - for task in pending: - task.cancel() # 대기 중인 코루틴 취소 - - # 5. 결과 수집 - results = [] - for task in tasks: - if task in done: - try: - # task.result()는 즉시 값을 반환함 (이미 완료되었으므로) - results.append(task.result()) - except asyncio.CancelledError: - results.append(None) - except Exception: - logger.exception("Task completed with exception") - results.append(None) - else: - results.append(None) - - return results diff --git a/app/adapter/request_single.py b/app/adapter/request_to_gpt.py similarity index 65% rename from app/adapter/request_single.py rename to app/adapter/request_to_gpt.py index cd5a1a4..8593279 100644 --- a/app/adapter/request_single.py +++ b/app/adapter/request_to_gpt.py @@ -1,28 +1,28 @@ -import asyncio +import os +from functools import lru_cache from fastapi import HTTPException -from openai import APITimeoutError +from openai import AsyncOpenAI, APITimeoutError -from app.client.oepn_ai import get_gpt_client from app.util.logger import logger -async def request_responses_output_text_async(gpt_request: dict, timeout: float) -> str: - """Responses API 요청을 비동기로 처리하며 타임아웃을 적용한다.""" - try: - return await asyncio.wait_for( - asyncio.to_thread(request_responses_output_text, gpt_request), - timeout=timeout, - ) - except asyncio.TimeoutError: - logger.error(f"Single request timed out after {timeout} seconds.") - raise HTTPException(status_code=408, detail="Request Timeout") +@lru_cache(maxsize=1) +def get_gpt_client() -> AsyncOpenAI: + """OpenAI 비동기 클라이언트를 캐싱하여(싱글톤처럼) 제공한다.""" + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise RuntimeError("OPENAI_API_KEY 환경변수가 설정되어 있지 않습니다.") + + return AsyncOpenAI(api_key=api_key, max_retries=0) -def request_responses_output_text(gpt_request: dict) -> str: +async def request_to_gpt_returning_text(gpt_request: dict, timeout: int) -> str: """Responses API로 단건 요청을 전송하고 텍스트만 추출한다.""" try: - resp = get_gpt_client().responses.create(**gpt_request) + client = get_gpt_client() + client = client.with_options(timeout=float(timeout)) + resp = await client.responses.create(**gpt_request) except APITimeoutError: logger.error("OpenAI API Timeout") raise HTTPException(status_code=429, detail="OpenAI API Timeout") diff --git a/app/client/__init__.py b/app/client/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/client/oepn_ai.py b/app/client/oepn_ai.py deleted file mode 100644 index 0d68235..0000000 --- a/app/client/oepn_ai.py +++ /dev/null @@ -1,18 +0,0 @@ -import os - -from openai import AsyncOpenAI - -_gpt_client: AsyncOpenAI | None = None - - -def get_gpt_client() -> AsyncOpenAI: - """OpenAI 비동기 클라이언트를 싱글톤으로 제공한다.""" - global _gpt_client - if _gpt_client is None: - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise RuntimeError("OPENAI_API_KEY 환경변수가 설정되어 있지 않습니다.") - - _gpt_client = AsyncOpenAI(api_key=api_key, max_retries=0) - - return _gpt_client diff --git a/app/dto/model/problem_set.py b/app/dto/model/problem_set.py index e1501af..14d006a 100644 --- a/app/dto/model/problem_set.py +++ b/app/dto/model/problem_set.py @@ -5,9 +5,7 @@ class Selection(BaseModel): content: str = Field(description="선택지 내용입니다.") - correct: bool = Field( - description="정답 여부입니다. 정답이면 True, 오답이면 False입니다." - ) + correct: bool = Field(description="정답 여부입니다. 정답이면 True, 오답이면 False입니다.") class Problem(BaseModel): diff --git a/app/dto/response/error_response.py b/app/dto/response/error_response.py new file mode 100644 index 0000000..c92c3e4 --- /dev/null +++ b/app/dto/response/error_response.py @@ -0,0 +1,9 @@ +from typing import Literal + +from pydantic import BaseModel + + +class ErrorResponse(BaseModel): + type: Literal["error"] = "error" + code: int + message: str diff --git a/app/dto/response/generate_response.py b/app/dto/response/generate_response.py index 44bbe61..0500fc8 100644 --- a/app/dto/response/generate_response.py +++ b/app/dto/response/generate_response.py @@ -1,11 +1,11 @@ -from typing import List +from typing import List, Literal from pydantic import BaseModel from app.dto.model.problem_set import Selection -class ProblemResponse(BaseModel): +class ProblemDTO(BaseModel): number: int title: str selections: List[Selection] @@ -13,5 +13,9 @@ class ProblemResponse(BaseModel): referencedPages: List[int] -class GenerateResponse(BaseModel): - quiz: List[ProblemResponse] +class ProblemSetDTO(BaseModel): + quiz: List[ProblemDTO] + + +class GenerateResponse(ProblemSetDTO): + type: Literal["quiz"] = "quiz" diff --git a/app/router/generate_router.py b/app/router/generate_router.py index a24464f..8e3f665 100644 --- a/app/router/generate_router.py +++ b/app/router/generate_router.py @@ -1,8 +1,8 @@ -from fastapi import APIRouter, Depends +from fastapi import APIRouter +from starlette.responses import StreamingResponse from app.dto.request.generate_request import GenerateRequest from app.dto.request.specific_explanation_request import SpecificExplanationRequest -from app.dto.response.generate_response import GenerateResponse from app.dto.response.specific_explanation_response import SpecificExplanationResponse from app.service.explanation_service import ExplanationService from app.service.generate_service import GenerateService @@ -10,24 +10,15 @@ router = APIRouter() -def get_generate_service(): - return GenerateService() - - -def get_explanation_service(): - return ExplanationService() - - @router.post("/generation") -async def generate( - request: GenerateRequest, generate_service=Depends(get_generate_service) -) -> GenerateResponse: - return await generate_service.generate(request) +async def generate(request: GenerateRequest) -> StreamingResponse: + return StreamingResponse( + GenerateService.generate(request), media_type="application/x-ndjson" + ) @router.post("/specific-explanation") async def generate_specific_explanation( request: SpecificExplanationRequest, - explanation_service=Depends(get_explanation_service), ) -> SpecificExplanationResponse: - return await explanation_service.generate_specific_explanation(request) + return await ExplanationService.generate_specific_explanation(request) diff --git a/app/service/explanation_service.py b/app/service/explanation_service.py index 5f140d7..1eec152 100644 --- a/app/service/explanation_service.py +++ b/app/service/explanation_service.py @@ -1,4 +1,4 @@ -from app.adapter.request_single import request_responses_output_text_async +from app.adapter.request_to_gpt import request_to_gpt_returning_text from app.dto.request.specific_explanation_request import SpecificExplanationRequest from app.dto.response.specific_explanation_response import SpecificExplanationResponse from app.util.logger import logger @@ -10,7 +10,6 @@ class ExplanationService: async def generate_specific_explanation( specific_explanation_request: SpecificExplanationRequest, ): - title = specific_explanation_request.title selections = specific_explanation_request.selections @@ -56,9 +55,7 @@ async def generate_specific_explanation( } with log_elapsed(logger, "request_specific_explanation_with_search"): - combined_text = await request_responses_output_text_async( - gpt_content, timeout=60 - ) + combined_text = await request_to_gpt_returning_text(gpt_content, timeout=30) combined_text = (combined_text or "").strip() references = [] diff --git a/app/service/generate_service.py b/app/service/generate_service.py index 7a1dbf2..3f8b8a1 100644 --- a/app/service/generate_service.py +++ b/app/service/generate_service.py @@ -1,175 +1,59 @@ +import asyncio import base64 import os import random from copy import deepcopy -from typing import Any, List +from typing import List, Optional from urllib.parse import urlparse import fitz import requests -from fastapi import HTTPException from langchain_core.output_parsers import JsonOutputParser -from app.adapter.request_batch import request_responses_batch -from app.dto.model.generated_result import GeneratedResult +from app.adapter.request_to_gpt import request_to_gpt_returning_text from app.dto.model.problem_set import ProblemSet -from app.dto.request.generate_request import GenerateRequest, QuizType +from app.dto.request.generate_request import GenerateRequest +from app.dto.response.error_response import ErrorResponse from app.dto.response.generate_response import ( - GenerateResponse, - ProblemResponse, + ProblemSetDTO, + ProblemDTO, GenerateResponse, ) from app.prompt import prompt_factory from app.util.create_chunks import create_page_chunks +from app.util.gpt_utils import enforce_additional_properties_false from app.util.logger import logger from app.util.rate_limiter import rate_limiter from app.util.timing import log_elapsed -def _enforce_additional_properties_false(schema: Any) -> Any: - """ - OpenAI Structured Outputs(json_schema, strict=True) 제약: - object 스키마는 additionalProperties=false 가 필요하다. - Pydantic이 생성한 JSON Schema에 이를 재귀적으로 주입한다. - """ - if isinstance(schema, list): - return [_enforce_additional_properties_false(s) for s in schema] - - if not isinstance(schema, dict): - return schema - - # object 타입이면 additionalProperties를 명시적으로 false로 고정 - if schema.get("type") == "object" and "additionalProperties" not in schema: - schema["additionalProperties"] = False - - # 자주 등장하는 하위 스키마 컨테이너들 재귀 순회 - dict_children_keys = ("properties", "$defs", "definitions") - for key in dict_children_keys: - child = schema.get(key) - if isinstance(child, dict): - for k, v in child.items(): - child[k] = _enforce_additional_properties_false(v) - - list_children_keys = ("anyOf", "oneOf", "allOf", "prefixItems") - for key in list_children_keys: - child = schema.get(key) - if isinstance(child, list): - schema[key] = [_enforce_additional_properties_false(v) for v in child] - - # 단일 하위 스키마들 - for key in ("items", "not", "if", "then", "else"): - child = schema.get(key) - if isinstance(child, (dict, list)): - schema[key] = _enforce_additional_properties_false(child) - - # additionalProperties가 dict로 오는 케이스도 대비(여기선 false로 고정하는 게 목적이지만, 안전하게 재귀 처리) - ap = schema.get("additionalProperties") - if isinstance(ap, dict): - schema["additionalProperties"] = _enforce_additional_properties_false(ap) - - return schema - - -def _extract_filename(uploaded_url: str) -> str: - parsed = urlparse(uploaded_url) - filename = os.path.basename(parsed.path) - return filename or "document.pdf" - - -def _load_pdf_content(uploaded_url: str) -> bytes: - response = requests.get(uploaded_url) - response.raise_for_status() - return response.content - - -def _get_pdf_page_count(file_content: bytes) -> int: - pdf_documents = fitz.open(stream=file_content, filetype="pdf") - page_count = len(pdf_documents) - pdf_documents.close() - return page_count - - -def _extract_pdf_pages_base64(file_content: bytes, pages: List[int]) -> str: - source = fitz.open(stream=file_content, filetype="pdf") - target = fitz.open() - for page_number in pages: - page_index = page_number - 1 - if 0 <= page_index < len(source): - target.insert_pdf(source, from_page=page_index, to_page=page_index) - pdf_bytes = target.tobytes() - target.close() - source.close() - return base64.b64encode(pdf_bytes).decode("ascii") - - class GenerateService: @staticmethod async def generate(generate_request: GenerateRequest): - quiz_count = generate_request.quizCount - uploaded_url = generate_request.uploadedUrl total_quiz_count = generate_request.quizCount - dok_level = generate_request.difficultyType - quiz_type = generate_request.quizType page_numbers = generate_request.pageNumbers - pdf_bytes = _load_pdf_content(uploaded_url) - page_count = _get_pdf_page_count(pdf_bytes) - - pdf_bytes = _load_pdf_content(uploaded_url) - page_count = _get_pdf_page_count(pdf_bytes) - - selected_pages = page_numbers - if not selected_pages: - selected_pages = list(range(1, page_count + 1)) - else: - selected_pages = [p for p in selected_pages if 0 < p <= page_count] - - texts = [""] * (len(selected_pages) + 1) - - max_chunk_count = 15 - chunks = create_page_chunks(len(texts) - 1, total_quiz_count, max_chunk_count) + chunks = create_page_chunks( + page_numbers, total_quiz_count, int(os.environ["MAX_CHUNK_COUNT"]) + ) await rate_limiter.check_rate(len(chunks)) - i = 0 - while i < len(chunks): - chunk = chunks[i] - - while chunk.quiz_count > 2: - # 기존 chunk 복제 - new_chunk = deepcopy(chunk) - new_chunk.quiz_count = 2 - # 현재 인덱스 i 앞에 삽입 - chunks.insert(i, new_chunk) - # 원본 chunk의 count 감소 - chunk.quiz_count -= 2 - # 삽입된 만큼 한 칸 이동 - i += 1 - - i += 1 - - page_index_source = selected_pages - for chunk in chunks: - chunk.referenced_pages = [ - page_index_source[i - 1] - for i in chunk.referenced_pages - if 1 <= i <= len(page_index_source) - ] - - parser = JsonOutputParser(pydantic_object=ProblemSet) - problem_set_json_schema = _enforce_additional_properties_false( + problem_set_json_schema = enforce_additional_properties_false( deepcopy(ProblemSet.model_json_schema()) ) - gpt_contents = [] + dok_level = generate_request.difficultyType + quiz_type = generate_request.quizType + uploaded_url = generate_request.uploadedUrl pdf_chunk_cache: dict[tuple[int, ...], str] = {} - - for chunk in chunks: + pdf_bytes = _load_pdf_content(uploaded_url) + for i, chunk in enumerate(chunks): system_message = f""" 당신은 대학 강의노트로부터 평가용 퀴즈를 생성하는 AI입니다. 주어진 강의노트 내용을 분석하여 학생들의 이해도를 평가할 수 있는 효과적인 퀴즈를 정확히 {chunk.quiz_count}개 생성하세요. 작성 규칙: - 한국어로 작성 - - 마크다운을 활용해 가독성을 높힌다 + - 적극적으로 개행하여 가독성에 신경 쓸 것 - 강의 노트를 참조하라는 문제 생성 금지 (예: "강의노트에 따르면", "본문을 참고하면" 등 금지) 문제 생성 지침(품질/난이도): @@ -179,121 +63,161 @@ async def generate(generate_request: GenerateRequest): {prompt_factory.get_quiz_format(quiz_type)} """.strip() - page_hint = ( - f"참조 페이지: {', '.join(map(str, chunk.referenced_pages))}" - if chunk.referenced_pages - else "참조 페이지: 없음" - ) pages_key = tuple(chunk.referenced_pages) if pages_key not in pdf_chunk_cache: pdf_chunk_cache[pages_key] = _extract_pdf_pages_base64( pdf_bytes, chunk.referenced_pages ) pdf_chunk_base64 = pdf_chunk_cache[pages_key] - gpt_contents.append( - { - "model": "gpt-4.1-mini", - "max_output_tokens": 10000, - "text": { - "format": { - "type": "json_schema", - "name": "problem_set", - "strict": True, - "schema": problem_set_json_schema, - } + if i < max(len(chunks) * 0.2, 3): + model = "gpt-4.1-mini" + else: + model = "gpt-5-mini" + chunk.gpt_content = { + "model": model, + "max_output_tokens": 10000, + "text": { + "format": { + "type": "json_schema", + "name": "problem_set", + "strict": True, + "schema": problem_set_json_schema, + } + }, + "input": [ + {"role": "system", "content": system_message}, + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": f"# 강의노트(PDF)", + }, + { + "type": "input_file", + "filename": _extract_filename(uploaded_url), + "file_data": f"data:application/pdf;base64,{pdf_chunk_base64}", + }, + ], }, - "input": [ - {"role": "system", "content": system_message}, - { - "role": "user", - "content": [ - { - "type": "input_text", - "text": f"# 강의노트(PDF)\n\n{page_hint}", - }, - { - "type": "input_file", - "filename": _extract_filename(uploaded_url), - "file_data": f"data:application/pdf;base64,{pdf_chunk_base64}", - }, - ], - }, - ], - } - ) + ], + } - with log_elapsed(logger, "request_generate_quiz"): - timeout = int(os.environ["GPT_REQUEST_TIMEOUT"]) - texts = await request_responses_batch(gpt_contents, timeout=timeout) - generated_results: List[GeneratedResult] = [] - for sequence, text in enumerate(texts, start=1): - if not text: - continue - generated_results.append( - GeneratedResult(sequence=sequence, generated_text=text) + tasks = [] + for chunk in chunks: + tasks.append( + asyncio.create_task( + process_single_chunk( + chunk.gpt_content, + JsonOutputParser(pydantic_object=ProblemSet), + chunk.referenced_pages, + quiz_type, + ) ) + ) - if not generated_results: - raise HTTPException( - status_code=429, - detail="모든 퀴즈 생성 요청이 실패하거나 시간 초과되었습니다.", - ) + # 스트리밍 응답 처리 + number = 1 + try: + for completed_task in asyncio.as_completed(tasks): + try: + result: Optional[ProblemSetDTO] = await completed_task + + if result: + for quiz in result.quiz: + quiz.number = number + number += 1 + yield result.model_dump_json() + "\n" + + except Exception as e: + logger.error(f"Task processing failed: {e}") + status_code = getattr(e, "status_code", 500) + yield ErrorResponse( + code=status_code, message=str(e) + ).model_dump_json() + "\n" + + except Exception as e: + logger.error(f"Critical streaming error: {e}") + yield ErrorResponse(code=500, message=str(e)).model_dump_json() + "\n" + + +async def process_single_chunk( + gpt_request: dict, + parser: JsonOutputParser, + referenced_pages: List[int], + quiz_type: str, +) -> Optional[ProblemSetDTO]: + with log_elapsed(logger, "request_generate_quiz"): + try: + text_response = await request_to_gpt_returning_text( + gpt_request, os.environ["GPT_REQUEST_TIMEOUT"] + ) - sorted_responses = [] - for i, generated_result in enumerate(generated_results): - try: - generated_text = parser.parse(generated_result.generated_text) + if not text_response: + return None - # 방어: 첫 문제 선택지가 4개 초과면 폐기 - quiz = ( - generated_text.get("quiz") - if isinstance(generated_text, dict) - else None - ) - if isinstance(quiz, list) and len(quiz) > 0: - selections = ( - quiz[0].get("selections") if isinstance(quiz[0], dict) else None - ) - if isinstance(selections, list) and len(selections) > 4: - continue + # 파싱 + generated_data = parser.parse(text_response) + + # 구조 검증 및 변환 + quiz_list = generated_data.get("quiz", []) + if not quiz_list: + return None + + if not isinstance(quiz_list, list) or len(quiz_list) == 0: + return None + + if len(quiz_list) > 0: + first_selections = quiz_list[0].get("selections") + if isinstance(first_selections, list) and len(first_selections) > 4: + return None + + # 문제 변환 (DTO 매핑) + problem_responses = [] + for q in quiz_list: + # 선택지 셔플 등 로직 수행 + selections = q.get("selections", []) + if quiz_type in ["MULTIPLE", "BLANK"] and selections: + random.shuffle(selections) - sorted_responses.append( - { - "sequence": generated_result.sequence, - "generated_text": generated_text, - } - ) - except Exception as e: - logger.error(f"Parsing error for response {i}: {e}") - logger.error(f"Response content: {generated_result.generated_text}") - continue - - sorted_responses.sort(key=lambda x: x["sequence"]) - - seq_to_pages = {i + 1: chunk.referenced_pages for i, chunk in enumerate(chunks)} - - problem_responses: List[ProblemResponse] = [] - for generated_result in sorted_responses: - quiz_data = generated_result.get("generated_text") - quiz = quiz_data.get("quiz") - for problem in quiz: - if ( - quiz_type == QuizType.MULTIPLE.value - or quiz_type == QuizType.BLANK.value - ): - random.shuffle(problem.get("selections")) problem_responses.append( - ProblemResponse( - **problem, - referencedPages=seq_to_pages.get( - generated_result["sequence"], [] - ), + ProblemDTO( + number=0, + title=q.get("title"), + selections=selections, + explanation=q.get("explanation"), + referencedPages=referenced_pages, ) ) - real_sequence_number = 1 - for i, problem in enumerate(problem_responses): - problem.number = real_sequence_number - real_sequence_number += 1 + # 1~2개의 문제가 담긴 부분 응답 객체 반환 + return GenerateResponse(quiz=problem_responses) + + except Exception as e: + logger.error(f"Chunk processing error: {e}") + raise e - return GenerateResponse(quiz=problem_responses) + +def _extract_filename(uploaded_url: str) -> str: + parsed = urlparse(uploaded_url) + filename = os.path.basename(parsed.path) + return filename or "document.pdf" + + +def _load_pdf_content(uploaded_url: str) -> bytes: + response = requests.get(uploaded_url) + response.raise_for_status() + return response.content + + +def _extract_pdf_pages_base64(file_content: bytes, pages: List[int]) -> str: + source = fitz.open(stream=file_content, filetype="pdf") + target = fitz.open() + for page_number in pages: + page_index = page_number - 1 + if 0 <= page_index < len(source): + target.insert_pdf(source, from_page=page_index, to_page=page_index) + pdf_bytes = target.tobytes() + target.close() + source.close() + return base64.b64encode(pdf_bytes).decode("ascii") diff --git a/app/util/create_chunks.py b/app/util/create_chunks.py index a0cfb69..edff522 100644 --- a/app/util/create_chunks.py +++ b/app/util/create_chunks.py @@ -1,113 +1,50 @@ -import math from typing import List from pydantic.v1 import BaseModel +from typing_extensions import Dict class ChunkInfo(BaseModel): referenced_pages: List[int] quiz_count: int + gpt_content: Dict = {} def create_page_chunks( - page_count: int, + page_numbers: List[int], total_quiz_count: int, max_chunk_count: int, ) -> List[ChunkInfo]: - if total_quiz_count < page_count: - chunks = handle_quiz_smaller_than_total_page(total_quiz_count, page_count) - elif total_quiz_count > page_count: - chunks = handle_quiz_larger_than_total_page(total_quiz_count, page_count) - else: - chunks = handle_quiz_same_as_total_page(total_quiz_count, page_count) - - return compress_chunks(max_chunk_count, chunks) - - -def compress_chunks(max_chunk_count: int, chunks: List[ChunkInfo]) -> List[ChunkInfo]: - if len(chunks) <= max_chunk_count: - return chunks - - base_size = len(chunks) // max_chunk_count - remainder = len(chunks) % max_chunk_count - - compressed_chunks = [] - idx = 0 - for i in range(max_chunk_count): - - group_size = base_size + 1 if i < remainder else base_size - - merged_chunk = ChunkInfo(referenced_pages=[], quiz_count=0) - index_set = set([]) - for _ in range(group_size): - chunk = chunks[idx] - for referenced_page in chunk.referenced_pages: - index_set.add(referenced_page) - idx += 1 - merged_chunk.quiz_count += chunk.quiz_count - - merged_chunk.referenced_pages = list(sorted(index_set)) - compressed_chunks.append(merged_chunk) - - return compressed_chunks - - -def handle_quiz_smaller_than_total_page( - total_quiz_count: int, page_count: int -) -> List[ChunkInfo]: - chunks = [] - page_per_quiz = page_count / total_quiz_count - - for quiz_sequence in range(total_quiz_count): - start = math.floor(quiz_sequence * page_per_quiz) + 1 - end = math.floor((quiz_sequence + 1) * page_per_quiz) - - referenced_pages = list(range(start, end + 1)) - - chunks.append(ChunkInfo(referenced_pages=referenced_pages, quiz_count=1)) - - return chunks - - -def handle_quiz_larger_than_total_page( - total_quiz_count: int, page_count: int -) -> List[ChunkInfo]: - - chunks = [] - page_per_quiz = page_count / total_quiz_count - - quiz_counts = [0] * (page_count + 1) - - for k in range(total_quiz_count): - page_idx = math.floor(k * page_per_quiz) + 1 - if page_idx > page_count: - page_idx = page_count - quiz_counts[page_idx] += 1 - - for i in range(1, page_count + 1): - chunk_info = ChunkInfo( - referenced_pages=[i], - quiz_count=quiz_counts[i], - ) - chunks.append(chunk_info) - - return chunks - -def handle_quiz_same_as_total_page( - total_quiz_count: int, page_count: int -) -> List[ChunkInfo]: - - chunks = [] - page_per_quiz = page_count / total_quiz_count - - for quiz_sequence in range(total_quiz_count): - cur_page_index = math.floor(quiz_sequence * page_per_quiz) + 1 - - chunks.append( - ChunkInfo( - referenced_pages=[cur_page_index], - quiz_count=1, - ) - ) + # 청크 별 퀴즈 개수 분배 + chunks: List[ChunkInfo] = [] + for i in range(total_quiz_count): + if i // max_chunk_count == 0: + chunks.append(ChunkInfo(quiz_count=1, referenced_pages=[])) + else: + chunks[i % max_chunk_count].quiz_count += 1 + + # 각 청크에 페이지 할당 + real_chunk_count = len(chunks) + page_count = len(page_numbers) + basic_page_count_per_chunk = page_count // real_chunk_count + extra_pages = page_count % real_chunk_count + cur = 0 + for chunk in chunks: + pages_for_this_chunk = basic_page_count_per_chunk + if extra_pages > 0: + pages_for_this_chunk += 1 + extra_pages -= 1 + + # 앞뒤로 한 페이지씩 여유를 둔다. + if pages_for_this_chunk < 3: + if cur == 0: + chunk.referenced_pages = page_numbers[0:3] + elif cur == len(page_numbers) - 1: + chunk.referenced_pages = page_numbers[-3:] + else: + chunk.referenced_pages = page_numbers[cur - 1 : cur + 2] + else: + chunk.referenced_pages = page_numbers[cur : cur + pages_for_this_chunk] + cur += pages_for_this_chunk return chunks diff --git a/app/util/gpt_utils.py b/app/util/gpt_utils.py new file mode 100644 index 0000000..2b287d6 --- /dev/null +++ b/app/util/gpt_utils.py @@ -0,0 +1,45 @@ +from typing import Any + + +def enforce_additional_properties_false(schema: Any) -> Any: + """ + OpenAI Structured Outputs(json_schema, strict=True) 제약: + object 스키마는 additionalProperties=false 가 필요하다. + Pydantic이 생성한 JSON Schema에 이를 재귀적으로 주입한다. + """ + if isinstance(schema, list): + return [enforce_additional_properties_false(s) for s in schema] + + if not isinstance(schema, dict): + return schema + + # object 타입이면 additionalProperties를 명시적으로 false로 고정 + if schema.get("type") == "object" and "additionalProperties" not in schema: + schema["additionalProperties"] = False + + # 자주 등장하는 하위 스키마 컨테이너들 재귀 순회 + dict_children_keys = ("properties", "$defs", "definitions") + for key in dict_children_keys: + child = schema.get(key) + if isinstance(child, dict): + for k, v in child.items(): + child[k] = enforce_additional_properties_false(v) + + list_children_keys = ("anyOf", "oneOf", "allOf", "prefixItems") + for key in list_children_keys: + child = schema.get(key) + if isinstance(child, list): + schema[key] = [enforce_additional_properties_false(v) for v in child] + + # 단일 하위 스키마들 + for key in ("items", "not", "if", "then", "else"): + child = schema.get(key) + if isinstance(child, (dict, list)): + schema[key] = enforce_additional_properties_false(child) + + # additionalProperties가 dict로 오는 케이스도 대비(여기선 false로 고정하는 게 목적이지만, 안전하게 재귀 처리) + ap = schema.get("additionalProperties") + if isinstance(ap, dict): + schema["additionalProperties"] = enforce_additional_properties_false(ap) + + return schema