Skip to content

LLM 토큰 사용량 집계 로직 TokenUtils 클래스로 분리 #117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 24 additions & 26 deletions interface/lang2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
from langchain_core.messages import AIMessage, HumanMessage

from llm_utils.connect_db import ConnectDB
from llm_utils.graph import builder
from llm_utils.enriched_graph import builder as enriched_builder
from llm_utils.display_chart import DisplayChart
from llm_utils.enriched_graph import builder as enriched_builder
from llm_utils.graph import builder
from llm_utils.llm_response_parser import LLMResponseParser
from llm_utils.token_utils import TokenUtils

TITLE = "Lang2SQL"
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
SIDEBAR_OPTIONS = {
"show_total_token_usage": "Show Total Token Usage",
"show_token_usage": "Show Token Usage",
"show_result_description": "Show Result Description",
"show_sql": "Show SQL",
"show_question_reinterpreted_by_ai": "Show User Question Reinterpreted by AI",
Expand All @@ -27,24 +29,6 @@
}


def summarize_total_tokens(data: list) -> int:
"""
메시지 데이터에서 총 토큰 사용량을 집계합니다.

Args:
data (list): usage_metadata를 포함하는 객체들의 리스트.

Returns:
int: 총 토큰 사용량 합계.
"""

total_tokens = 0
for item in data:
token_usage = getattr(item, "usage_metadata", {})
total_tokens += token_usage.get("total_tokens", 0)
return total_tokens


def execute_query(
*,
query: str,
Expand Down Expand Up @@ -119,14 +103,22 @@ def display_result(
"""

def should_show(_key: str) -> bool:
st.markdown("---")
return st.session_state.get(_key, True)

if should_show("show_total_token_usage"):
total_tokens = summarize_total_tokens(res["messages"])
st.write("**총 토큰 사용량:**", total_tokens)
if should_show("show_token_usage"):
st.markdown("---")
token_summary = TokenUtils.get_token_usage_summary(data=res["messages"])
st.write("**토큰 사용량:**")
st.markdown(
f"""
- Input tokens: `{token_summary['input_tokens']}`
- Output tokens: `{token_summary['output_tokens']}`
- Total tokens: `{token_summary['total_tokens']}`
"""
)

if should_show("show_sql"):
st.markdown("---")
generated_query = res.get("generated_query")
query_text = (
generated_query.content
Expand All @@ -148,6 +140,7 @@ def should_show(_key: str) -> bool:
st.code(interpretation)

if should_show("show_result_description"):
st.markdown("---")
st.markdown("**결과 설명:**")
result_message = res["messages"][-1].content

Expand All @@ -163,14 +156,17 @@ def should_show(_key: str) -> bool:
st.code(interpretation, language="plaintext")

if should_show("show_question_reinterpreted_by_ai"):
st.markdown("---")
st.markdown("**AI가 재해석한 사용자 질문:**")
st.code(res["refined_input"].content)

if should_show("show_referenced_tables"):
st.markdown("---")
st.markdown("**참고한 테이블 목록:**")
st.write(res.get("searched_tables", []))

if should_show("show_table"):
st.markdown("---")
try:
sql_raw = (
res["generated_query"].content
Expand All @@ -182,7 +178,9 @@ def should_show(_key: str) -> bool:
st.dataframe(df.head(10) if len(df) > 10 else df)
except Exception as e:
st.error(f"쿼리 실행 중 오류 발생: {e}")

if should_show("show_chart"):
st.markdown("---")
df = database.run_sql(sql)
st.markdown("**쿼리 결과 시각화:**")
display_code = DisplayChart(
Expand All @@ -199,7 +197,7 @@ def should_show(_key: str) -> bool:

db = ConnectDB()

st.title("Lang2SQL")
st.title(TITLE)

# 워크플로우 선택(UI)
use_enriched = st.sidebar.checkbox(
Expand Down
93 changes: 93 additions & 0 deletions llm_utils/token_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
token_utils.py

LLM 응답 메시지에서 토큰 사용량을 집계하기 위한 유틸리티 모듈입니다.

이 모듈은 LLM의 `usage_metadata` 필드를 기반으로 입력 토큰, 출력 토큰, 총 토큰 사용량을 계산하는 기능을 제공합니다.
Streamlit, LangChain 등 LLM 응답을 다루는 애플리케이션에서 비용 분석, 사용량 추적 등에 활용할 수 있습니다.
"""

import logging
from typing import Any, List

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


class TokenUtils:
"""
LLM 토큰 사용량 집계 유틸리티 클래스입니다.

이 클래스는 LLM 응답 메시지 리스트에서 usage_metadata 필드를 추출하여
input_tokens, output_tokens, total_tokens의 합계를 계산합니다.

예를 들어, LangChain 또는 OpenAI API 응답 메시지 객체의 토큰 사용 정보를 분석하고자 할 때
활용할 수 있습니다.

사용 예:
>>> from token_utils import TokenUtils
>>> summary = TokenUtils.get_token_usage_summary(messages)
>>> print(summary["total_tokens"])

반환 형식:
{
"input_tokens": int,
"output_tokens": int,
"total_tokens": int,
}
"""

@staticmethod
def get_token_usage_summary(*, data: List[Any]) -> dict:
"""
메시지 데이터에서 input/output/total 토큰 사용량을 각각 집계합니다.

Args:
data (List[Any]): 각 항목이 usage_metadata를 포함할 수 있는 객체 리스트.

Returns:
dict: {
"input_tokens": int,
"output_tokens": int,
"total_tokens": int
}
"""

input_tokens = 0
output_tokens = 0
total_tokens = 0

for idx, item in enumerate(data):
token_usage = getattr(item, "usage_metadata", {})
in_tok = token_usage.get("input_tokens", 0)
out_tok = token_usage.get("output_tokens", 0)
total_tok = token_usage.get("total_tokens", 0)

logger.debug(
"Message[%d] → input=%d, output=%d, total=%d",
idx,
in_tok,
out_tok,
total_tok,
)

input_tokens += in_tok
output_tokens += out_tok
total_tokens += total_tok

logger.info(
"Token usage summary → input: %d, output: %d, total: %d",
input_tokens,
output_tokens,
total_tokens,
)

return {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens,
}