From ead12d50682a93bb5aff7242b742f86a49f03ef2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E1=84=87=E1=85=A1=E1=86=A8=E1=84=80=E1=85=A7=E1=86=BC?= =?UTF-8?q?=E1=84=90=E1=85=A2?= Date: Sat, 24 May 2025 12:38:36 +0900 Subject: [PATCH] refactor: extract TokenUtils class for LLM token usage summarization --- interface/lang2sql.py | 50 +++++++++++---------- llm_utils/token_utils.py | 93 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 26 deletions(-) create mode 100644 llm_utils/token_utils.py diff --git a/interface/lang2sql.py b/interface/lang2sql.py index a41068f..9d94268 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -10,14 +10,16 @@ from langchain_core.messages import AIMessage, HumanMessage from llm_utils.connect_db import ConnectDB -from llm_utils.graph import builder -from llm_utils.enriched_graph import builder as enriched_builder from llm_utils.display_chart import DisplayChart +from llm_utils.enriched_graph import builder as enriched_builder +from llm_utils.graph import builder from llm_utils.llm_response_parser import LLMResponseParser +from llm_utils.token_utils import TokenUtils +TITLE = "Lang2SQL" DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" SIDEBAR_OPTIONS = { - "show_total_token_usage": "Show Total Token Usage", + "show_token_usage": "Show Token Usage", "show_result_description": "Show Result Description", "show_sql": "Show SQL", "show_question_reinterpreted_by_ai": "Show User Question Reinterpreted by AI", @@ -27,24 +29,6 @@ } -def summarize_total_tokens(data: list) -> int: - """ - 메시지 데이터에서 총 토큰 사용량을 집계합니다. - - Args: - data (list): usage_metadata를 포함하는 객체들의 리스트. - - Returns: - int: 총 토큰 사용량 합계. - """ - - total_tokens = 0 - for item in data: - token_usage = getattr(item, "usage_metadata", {}) - total_tokens += token_usage.get("total_tokens", 0) - return total_tokens - - def execute_query( *, query: str, @@ -119,14 +103,22 @@ def display_result( """ def should_show(_key: str) -> bool: - st.markdown("---") return st.session_state.get(_key, True) - if should_show("show_total_token_usage"): - total_tokens = summarize_total_tokens(res["messages"]) - st.write("**총 토큰 사용량:**", total_tokens) + if should_show("show_token_usage"): + st.markdown("---") + token_summary = TokenUtils.get_token_usage_summary(data=res["messages"]) + st.write("**토큰 사용량:**") + st.markdown( + f""" + - Input tokens: `{token_summary['input_tokens']}` + - Output tokens: `{token_summary['output_tokens']}` + - Total tokens: `{token_summary['total_tokens']}` + """ + ) if should_show("show_sql"): + st.markdown("---") generated_query = res.get("generated_query") query_text = ( generated_query.content @@ -148,6 +140,7 @@ def should_show(_key: str) -> bool: st.code(interpretation) if should_show("show_result_description"): + st.markdown("---") st.markdown("**결과 설명:**") result_message = res["messages"][-1].content @@ -163,14 +156,17 @@ def should_show(_key: str) -> bool: st.code(interpretation, language="plaintext") if should_show("show_question_reinterpreted_by_ai"): + st.markdown("---") st.markdown("**AI가 재해석한 사용자 질문:**") st.code(res["refined_input"].content) if should_show("show_referenced_tables"): + st.markdown("---") st.markdown("**참고한 테이블 목록:**") st.write(res.get("searched_tables", [])) if should_show("show_table"): + st.markdown("---") try: sql_raw = ( res["generated_query"].content @@ -182,7 +178,9 @@ def should_show(_key: str) -> bool: st.dataframe(df.head(10) if len(df) > 10 else df) except Exception as e: st.error(f"쿼리 실행 중 오류 발생: {e}") + if should_show("show_chart"): + st.markdown("---") df = database.run_sql(sql) st.markdown("**쿼리 결과 시각화:**") display_code = DisplayChart( @@ -199,7 +197,7 @@ def should_show(_key: str) -> bool: db = ConnectDB() -st.title("Lang2SQL") +st.title(TITLE) # 워크플로우 선택(UI) use_enriched = st.sidebar.checkbox( diff --git a/llm_utils/token_utils.py b/llm_utils/token_utils.py new file mode 100644 index 0000000..5aef3e0 --- /dev/null +++ b/llm_utils/token_utils.py @@ -0,0 +1,93 @@ +""" +token_utils.py + +LLM 응답 메시지에서 토큰 사용량을 집계하기 위한 유틸리티 모듈입니다. + +이 모듈은 LLM의 `usage_metadata` 필드를 기반으로 입력 토큰, 출력 토큰, 총 토큰 사용량을 계산하는 기능을 제공합니다. +Streamlit, LangChain 등 LLM 응답을 다루는 애플리케이션에서 비용 분석, 사용량 추적 등에 활용할 수 있습니다. +""" + +import logging +from typing import Any, List + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger(__name__) + + +class TokenUtils: + """ + LLM 토큰 사용량 집계 유틸리티 클래스입니다. + + 이 클래스는 LLM 응답 메시지 리스트에서 usage_metadata 필드를 추출하여 + input_tokens, output_tokens, total_tokens의 합계를 계산합니다. + + 예를 들어, LangChain 또는 OpenAI API 응답 메시지 객체의 토큰 사용 정보를 분석하고자 할 때 + 활용할 수 있습니다. + + 사용 예: + >>> from token_utils import TokenUtils + >>> summary = TokenUtils.get_token_usage_summary(messages) + >>> print(summary["total_tokens"]) + + 반환 형식: + { + "input_tokens": int, + "output_tokens": int, + "total_tokens": int, + } + """ + + @staticmethod + def get_token_usage_summary(*, data: List[Any]) -> dict: + """ + 메시지 데이터에서 input/output/total 토큰 사용량을 각각 집계합니다. + + Args: + data (List[Any]): 각 항목이 usage_metadata를 포함할 수 있는 객체 리스트. + + Returns: + dict: { + "input_tokens": int, + "output_tokens": int, + "total_tokens": int + } + """ + + input_tokens = 0 + output_tokens = 0 + total_tokens = 0 + + for idx, item in enumerate(data): + token_usage = getattr(item, "usage_metadata", {}) + in_tok = token_usage.get("input_tokens", 0) + out_tok = token_usage.get("output_tokens", 0) + total_tok = token_usage.get("total_tokens", 0) + + logger.debug( + "Message[%d] → input=%d, output=%d, total=%d", + idx, + in_tok, + out_tok, + total_tok, + ) + + input_tokens += in_tok + output_tokens += out_tok + total_tokens += total_tok + + logger.info( + "Token usage summary → input: %d, output: %d, total: %d", + input_tokens, + output_tokens, + total_tokens, + ) + + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens, + }