diff --git a/nilai-api/examples/users.py b/nilai-api/examples/users.py index 3ec30da9..9fad88d9 100644 --- a/nilai-api/examples/users.py +++ b/nilai-api/examples/users.py @@ -14,19 +14,19 @@ async def main(): print(f"Alice's details: {alice}") # Check API key - user_name = await UserManager.check_api_key(bob["apikey"]) + user_name = await UserManager.check_api_key(bob.apikey) print(f"API key validation: {user_name}") # Update and retrieve token usage await UserManager.update_token_usage( - bob["userid"], prompt_tokens=50, completion_tokens=20 + bob.userid, prompt_tokens=50, completion_tokens=20 ) - usage = await UserManager.get_user_token_usage(bob["userid"]) + usage = await UserManager.get_user_token_usage(bob.userid) print(f"Bob's token usage: {usage}") # Log a query await QueryLogManager.log_query( - userid=bob["userid"], + userid=bob.userid, model="gpt-3.5-turbo", prompt_tokens=8, completion_tokens=7, diff --git a/nilai-api/src/nilai_api/auth/jwt.py b/nilai-api/src/nilai_api/auth/jwt.py index 3e15f3e3..b0e796b8 100644 --- a/nilai-api/src/nilai_api/auth/jwt.py +++ b/nilai-api/src/nilai_api/auth/jwt.py @@ -58,7 +58,7 @@ def serialize_sign_doc(sign_doc: dict) -> bytes: def keplr_validate( - message: str, header: dict, payload: dict, signature: str + message: str, header: dict, payload: dict, signature: bytes ) -> JWTAuthResult: # Validate the algorithm if header["alg"] != "ES256": @@ -95,13 +95,15 @@ def keplr_validate( serialized_sign_doc, ) - return JWTAuthResult( - pub_key=payload.get("pub_key"), user_address=payload.get("user_address") - ) + pub_key = payload.get("pub_key") + user_address = payload.get("user_address") + if not pub_key or not user_address: + raise ValueError("Invalid payload, missing pub_key or user_address") + return JWTAuthResult(pub_key=pub_key, user_address=user_address) def metamask_validate( - message: str, header: dict, payload: dict, signature: str + message: str, header: dict, payload: dict, signature: bytes ) -> JWTAuthResult: # Validate the algorithm if header["alg"] != "ES256K": @@ -110,20 +112,23 @@ def metamask_validate( if payload.get("exp") and payload["exp"] < int(time.time()): raise ValueError("Token has expired") w3 = Web3(Web3.HTTPProvider("")) - message = encode_defunct(text=message) + signable_message = encode_defunct(text=message) address = w3.eth.account.recover_message( - message, signature=HexBytes("0x" + signature.hex()) + signable_message, signature=HexBytes("0x" + signature.hex()) ) if address.lower() != payload.get("user_address"): raise ValueError("Invalid signature") - return JWTAuthResult( - pub_key=payload.get("pub_key"), user_address=payload.get("user_address") - ) + pub_key = payload.get("pub_key") + user_address = payload.get("user_address") + if not pub_key or not user_address: + raise ValueError("Invalid payload, missing pub_key or user_address") + + return JWTAuthResult(pub_key=pub_key, user_address=user_address) -def extract_fields(jwt: str) -> tuple[str, dict, dict, str]: +def extract_fields(jwt: str) -> tuple[str, dict, dict, bytes]: # Split and decode JWT components header_b64, payload_b64, signature_b64 = jwt.split(".") if not all([header_b64, payload_b64, signature_b64]): diff --git a/nilai-api/src/nilai_api/db/logs.py b/nilai-api/src/nilai_api/db/logs.py index d18dd411..c4e4b854 100644 --- a/nilai-api/src/nilai_api/db/logs.py +++ b/nilai-api/src/nilai_api/db/logs.py @@ -15,17 +15,17 @@ class QueryLog(Base): __tablename__ = "query_logs" - id = Column(Integer, primary_key=True, autoincrement=True) - userid = Column( + id: int = Column(Integer, primary_key=True, autoincrement=True) # type: ignore + userid: str = Column( String(36), ForeignKey(UserModel.userid), nullable=False, index=True - ) - query_timestamp = Column( + ) # type: ignore + query_timestamp: datetime = Column( DateTime, server_default=sqlalchemy.func.now(), nullable=False - ) - model = Column(Text, nullable=False) - prompt_tokens = Column(Integer, nullable=False) - completion_tokens = Column(Integer, nullable=False) - total_tokens = Column(Integer, nullable=False) + ) # type: ignore + model: str = Column(Text, nullable=False) # type: ignore + prompt_tokens: int = Column(Integer, nullable=False) # type: ignore + completion_tokens: int = Column(Integer, nullable=False) # type: ignore + total_tokens: int = Column(Integer, nullable=False) # type: ignore def __repr__(self): return f"" diff --git a/nilai-api/src/nilai_api/db/users.py b/nilai-api/src/nilai_api/db/users.py index 32a1a25a..0eb15764 100644 --- a/nilai-api/src/nilai_api/db/users.py +++ b/nilai-api/src/nilai_api/db/users.py @@ -23,18 +23,20 @@ class UserModel(Base): __tablename__ = "users" - userid = Column(String(50), primary_key=True, index=True) - name = Column(String(100), nullable=False) - email = Column(String(255), unique=True, nullable=False, index=True) - apikey = Column(String(50), unique=True, nullable=False, index=True) - prompt_tokens = Column(Integer, default=0, nullable=False) - completion_tokens = Column(Integer, default=0, nullable=False) - queries = Column(Integer, default=0, nullable=False) - signup_date = Column(DateTime, server_default=sqlalchemy.func.now(), nullable=False) - last_activity = Column(DateTime, nullable=True) - ratelimit_day = Column(Integer, default=1000, nullable=True) - ratelimit_hour = Column(Integer, default=100, nullable=True) - ratelimit_minute = Column(Integer, default=10, nullable=True) + userid: str = Column(String(50), primary_key=True, index=True) # type: ignore + name: str = Column(String(100), nullable=False) # type: ignore + email: str = Column(String(255), unique=True, nullable=False, index=True) # type: ignore + apikey: str = Column(String(50), unique=True, nullable=False, index=True) # type: ignore + prompt_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore + completion_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore + queries: int = Column(Integer, default=0, nullable=False) # type: ignore + signup_date: datetime = Column( + DateTime, server_default=sqlalchemy.func.now(), nullable=False + ) # type: ignore + last_activity: datetime = Column(DateTime, nullable=True) # type: ignore + ratelimit_day: int = Column(Integer, default=1000, nullable=True) # type: ignore + ratelimit_hour: int = Column(Integer, default=100, nullable=True) # type: ignore + ratelimit_minute: int = Column(Integer, default=10, nullable=True) # type: ignore def __repr__(self): return f"" @@ -82,7 +84,7 @@ async def update_last_activity(userid: str): logger.error(f"Error updating last activity: {e}") @staticmethod - async def insert_user(name: str, email: str) -> Dict[str, str]: + async def insert_user(name: str, email: str) -> UserModel: """ Insert a new user into the database. @@ -104,10 +106,10 @@ async def insert_user(name: str, email: str) -> Dict[str, str]: ratelimit_hour=USER_RATE_LIMIT_HOUR, ratelimit_minute=USER_RATE_LIMIT_MINUTE, ) - UserManager.insert_user_model(user) + return await UserManager.insert_user_model(user) @staticmethod - async def insert_user_model(user: UserModel): + async def insert_user_model(user: UserModel) -> UserModel: """ Insert a new user model into the database. @@ -119,6 +121,7 @@ async def insert_user_model(user: UserModel): session.add(user) await session.commit() logger.info(f"User {user.name} added successfully.") + return user except SQLAlchemyError as e: logger.error(f"Error inserting user: {e}") raise @@ -136,9 +139,8 @@ async def check_api_key(api_key: str) -> Optional[UserModel]: """ try: async with get_db_session() as session: - user = await session.execute( - sqlalchemy.select(UserModel).filter(UserModel.apikey == api_key) - ) + query = sqlalchemy.select(UserModel).filter(UserModel.apikey == api_key) # type: ignore + user = await session.execute(query) user = user.scalar_one_or_none() return user except SQLAlchemyError as e: diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 1bf04739..6b7e0cba 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -59,7 +59,7 @@ async def get_usage(user: UserModel = Depends(get_user)) -> Usage: prompt_tokens=user.prompt_tokens, completion_tokens=user.completion_tokens, total_tokens=user.prompt_tokens + user.completion_tokens, - queries=user.queries, # FIXME this field is not part of Usage + queries=user.queries, # type: ignore # FIXME this field is not part of Usage )