Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions nilai-api/examples/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ async def main():
print(f"Alice's details: {alice}")

# Check API key
user_name = await UserManager.check_api_key(bob["apikey"])
user_name = await UserManager.check_api_key(bob.apikey)
print(f"API key validation: {user_name}")

# Update and retrieve token usage
await UserManager.update_token_usage(
bob["userid"], prompt_tokens=50, completion_tokens=20
bob.userid, prompt_tokens=50, completion_tokens=20
)
usage = await UserManager.get_user_token_usage(bob["userid"])
usage = await UserManager.get_user_token_usage(bob.userid)
print(f"Bob's token usage: {usage}")

# Log a query
await QueryLogManager.log_query(
userid=bob["userid"],
userid=bob.userid,
model="gpt-3.5-turbo",
prompt_tokens=8,
completion_tokens=7,
Expand Down
27 changes: 16 additions & 11 deletions nilai-api/src/nilai_api/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def serialize_sign_doc(sign_doc: dict) -> bytes:


def keplr_validate(
message: str, header: dict, payload: dict, signature: str
message: str, header: dict, payload: dict, signature: bytes
) -> JWTAuthResult:
# Validate the algorithm
if header["alg"] != "ES256":
Expand Down Expand Up @@ -95,13 +95,15 @@ def keplr_validate(
serialized_sign_doc,
)

return JWTAuthResult(
pub_key=payload.get("pub_key"), user_address=payload.get("user_address")
)
pub_key = payload.get("pub_key")
user_address = payload.get("user_address")
if not pub_key or not user_address:
raise ValueError("Invalid payload, missing pub_key or user_address")
return JWTAuthResult(pub_key=pub_key, user_address=user_address)


def metamask_validate(
message: str, header: dict, payload: dict, signature: str
message: str, header: dict, payload: dict, signature: bytes
) -> JWTAuthResult:
# Validate the algorithm
if header["alg"] != "ES256K":
Expand All @@ -110,20 +112,23 @@ def metamask_validate(
if payload.get("exp") and payload["exp"] < int(time.time()):
raise ValueError("Token has expired")
w3 = Web3(Web3.HTTPProvider(""))
message = encode_defunct(text=message)
signable_message = encode_defunct(text=message)
address = w3.eth.account.recover_message(
message, signature=HexBytes("0x" + signature.hex())
signable_message, signature=HexBytes("0x" + signature.hex())
)

if address.lower() != payload.get("user_address"):
raise ValueError("Invalid signature")

return JWTAuthResult(
pub_key=payload.get("pub_key"), user_address=payload.get("user_address")
)
pub_key = payload.get("pub_key")
user_address = payload.get("user_address")
if not pub_key or not user_address:
raise ValueError("Invalid payload, missing pub_key or user_address")

return JWTAuthResult(pub_key=pub_key, user_address=user_address)


def extract_fields(jwt: str) -> tuple[str, dict, dict, str]:
def extract_fields(jwt: str) -> tuple[str, dict, dict, bytes]:
# Split and decode JWT components
header_b64, payload_b64, signature_b64 = jwt.split(".")
if not all([header_b64, payload_b64, signature_b64]):
Expand Down
18 changes: 9 additions & 9 deletions nilai-api/src/nilai_api/db/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
class QueryLog(Base):
__tablename__ = "query_logs"

id = Column(Integer, primary_key=True, autoincrement=True)
userid = Column(
id: int = Column(Integer, primary_key=True, autoincrement=True) # type: ignore
userid: str = Column(
String(36), ForeignKey(UserModel.userid), nullable=False, index=True
)
query_timestamp = Column(
) # type: ignore
query_timestamp: datetime = Column(
DateTime, server_default=sqlalchemy.func.now(), nullable=False
)
model = Column(Text, nullable=False)
prompt_tokens = Column(Integer, nullable=False)
completion_tokens = Column(Integer, nullable=False)
total_tokens = Column(Integer, nullable=False)
) # type: ignore
model: str = Column(Text, nullable=False) # type: ignore
prompt_tokens: int = Column(Integer, nullable=False) # type: ignore
completion_tokens: int = Column(Integer, nullable=False) # type: ignore
total_tokens: int = Column(Integer, nullable=False) # type: ignore

def __repr__(self):
return f"<QueryLog(userid={self.userid}, query_timestamp={self.query_timestamp}, total_tokens={self.total_tokens})>"
Expand Down
38 changes: 20 additions & 18 deletions nilai-api/src/nilai_api/db/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@
class UserModel(Base):
__tablename__ = "users"

userid = Column(String(50), primary_key=True, index=True)
name = Column(String(100), nullable=False)
email = Column(String(255), unique=True, nullable=False, index=True)
apikey = Column(String(50), unique=True, nullable=False, index=True)
prompt_tokens = Column(Integer, default=0, nullable=False)
completion_tokens = Column(Integer, default=0, nullable=False)
queries = Column(Integer, default=0, nullable=False)
signup_date = Column(DateTime, server_default=sqlalchemy.func.now(), nullable=False)
last_activity = Column(DateTime, nullable=True)
ratelimit_day = Column(Integer, default=1000, nullable=True)
ratelimit_hour = Column(Integer, default=100, nullable=True)
ratelimit_minute = Column(Integer, default=10, nullable=True)
userid: str = Column(String(50), primary_key=True, index=True) # type: ignore
name: str = Column(String(100), nullable=False) # type: ignore
email: str = Column(String(255), unique=True, nullable=False, index=True) # type: ignore
apikey: str = Column(String(50), unique=True, nullable=False, index=True) # type: ignore
prompt_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore
completion_tokens: int = Column(Integer, default=0, nullable=False) # type: ignore
queries: int = Column(Integer, default=0, nullable=False) # type: ignore
signup_date: datetime = Column(
DateTime, server_default=sqlalchemy.func.now(), nullable=False
) # type: ignore
last_activity: datetime = Column(DateTime, nullable=True) # type: ignore
ratelimit_day: int = Column(Integer, default=1000, nullable=True) # type: ignore
ratelimit_hour: int = Column(Integer, default=100, nullable=True) # type: ignore
ratelimit_minute: int = Column(Integer, default=10, nullable=True) # type: ignore

def __repr__(self):
return f"<User(userid={self.userid}, name={self.name}, email={self.email})>"
Expand Down Expand Up @@ -82,7 +84,7 @@ async def update_last_activity(userid: str):
logger.error(f"Error updating last activity: {e}")

@staticmethod
async def insert_user(name: str, email: str) -> Dict[str, str]:
async def insert_user(name: str, email: str) -> UserModel:
"""
Insert a new user into the database.

Expand All @@ -104,10 +106,10 @@ async def insert_user(name: str, email: str) -> Dict[str, str]:
ratelimit_hour=USER_RATE_LIMIT_HOUR,
ratelimit_minute=USER_RATE_LIMIT_MINUTE,
)
UserManager.insert_user_model(user)
return await UserManager.insert_user_model(user)

@staticmethod
async def insert_user_model(user: UserModel):
async def insert_user_model(user: UserModel) -> UserModel:
"""
Insert a new user model into the database.

Expand All @@ -119,6 +121,7 @@ async def insert_user_model(user: UserModel):
session.add(user)
await session.commit()
logger.info(f"User {user.name} added successfully.")
return user
except SQLAlchemyError as e:
logger.error(f"Error inserting user: {e}")
raise
Expand All @@ -136,9 +139,8 @@ async def check_api_key(api_key: str) -> Optional[UserModel]:
"""
try:
async with get_db_session() as session:
user = await session.execute(
sqlalchemy.select(UserModel).filter(UserModel.apikey == api_key)
)
query = sqlalchemy.select(UserModel).filter(UserModel.apikey == api_key) # type: ignore
user = await session.execute(query)
user = user.scalar_one_or_none()
return user
except SQLAlchemyError as e:
Expand Down
2 changes: 1 addition & 1 deletion nilai-api/src/nilai_api/routers/private.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def get_usage(user: UserModel = Depends(get_user)) -> Usage:
prompt_tokens=user.prompt_tokens,
completion_tokens=user.completion_tokens,
total_tokens=user.prompt_tokens + user.completion_tokens,
queries=user.queries, # FIXME this field is not part of Usage
queries=user.queries, # type: ignore # FIXME this field is not part of Usage
)


Expand Down