diff --git a/nilai/db.py b/nilai/db.py index f39d6e12..539aad25 100644 --- a/nilai/db.py +++ b/nilai/db.py @@ -55,6 +55,7 @@ class User(Base): apikey = Column(String(36), unique=True, nullable=False, index=True) prompt_tokens = Column(Integer, default=0, nullable=False) completion_tokens = Column(Integer, default=0, nullable=False) + queries = Column(Integer, default=0, nullable=False) def __repr__(self): return f"" @@ -67,6 +68,7 @@ class UserData: apikey: str input_tokens: int generated_tokens: int + queries: int # Context manager for database sessions @@ -180,6 +182,7 @@ def update_token_usage(userid: str, prompt_tokens: int, completion_tokens: int): if user: user.prompt_tokens += prompt_tokens # type: ignore user.completion_tokens += completion_tokens # type: ignore + user.queries += 1 # type: ignore logger.info(f"Updated token usage for user {userid}") else: logger.warning(f"User {userid} not found") @@ -206,6 +209,7 @@ def get_token_usage( "prompt_tokens": user.prompt_tokens, "completion_tokens": user.completion_tokens, "total_tokens": user.prompt_tokens + user.completion_tokens, + "queries": user.queries, } else: logger.warning(f"User {userid} not found") @@ -230,6 +234,7 @@ def get_all_users() -> Optional[List[UserData]]: apikey=user.apikey, # type: ignore input_tokens=user.prompt_tokens, # type: ignore generated_tokens=user.completion_tokens, # type: ignore + queries=user.queries, # type: ignore ) for user in users ] @@ -255,6 +260,7 @@ def get_user_token_usage(userid: str) -> Optional[Dict[str, int]]: return { "prompt_tokens": user.prompt_tokens, "completion_tokens": user.completion_tokens, + "queries": user.queries, } # type: ignore return None except SQLAlchemyError as e: