Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .env.sample
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
HUGGINGFACE_API_TOKEN="<your_huggingface_token>"
HUGGINGFACE_API_TOKEN="<your_huggingface_token>"
OMP_NUM_THREADS=8
MKL_NUM_THREADS=8
NUMEXPR_NUM_THREADS=8
OPENBLAS_NUM_THREADS=8
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

Copy the `.env.sample` to `.env` to and replace the value of the `HUGGINGFACE_API_TOKEN` with the appropriate value. It is required to download Llama3.2 1B.

For development environments:
```shell
docker compose up --build web
docker compose -f docker-compose.yml -f docker-compose.dev.yml up --build nilai
```

For production environments:
```shell
docker compose -f docker-compose.yml -f docker-compose.prod.yml up -d
```

```
uv run gunicorn -c gunicorn.conf.py nilai.__main__:app
```
2 changes: 2 additions & 0 deletions caddy/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
caddy_config/
caddy_data/
10 changes: 10 additions & 0 deletions caddy/Caddyfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
(ssl_config) {
tls {
protocols tls1.2 tls1.3
}
}

https://nilai.sandbox.nilogy.xyz {
import ssl_config
reverse_proxy nilai:8443
}
4 changes: 4 additions & 0 deletions docker-compose.dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
services:
nilai:
ports:
- "8080:8080"
21 changes: 21 additions & 0 deletions docker-compose.prod.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
services:
nilai:
networks:
- proxy_net
caddy:
image: caddy:latest
container_name: caddy
restart: unless-stopped
networks:
- proxy_net
ports:
- "80:80"
- "443:443"
- "443:443/udp"
volumes:
- ./caddy/Caddyfile:/etc/caddy/Caddyfile
- ./caddy/caddy_data:/data
- ./caddy/caddy_config:/config

networks:
proxy_net:
3 changes: 0 additions & 3 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@ services:
build:
context: .
dockerfile: docker/Dockerfile
ports:
- "12345:12345"
volumes:
- ${PWD}/db/:/app/db/ # sqlite database for users
- hugging_face_models:/root/.cache/huggingface # cache models

volumes:
hugging_face_models:
18 changes: 10 additions & 8 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
FROM python:3.12-slim

COPY --link nilai /app/nilai
COPY pyproject.toml uv.lock .env /app/
COPY pyproject.toml uv.lock .env gunicorn.conf.py /app/

WORKDIR /app

RUN pip install uv
RUN uv sync
RUN apt-get update && \
apt-get install build-essential certbot -y && \
apt-get clean && \
apt-get autoremove && \
rm -rf /var/lib/apt/lists/* && \
pip install uv && \
uv sync

EXPOSE 12345
EXPOSE 8080 8443

# ENTRYPOINT ["uv", "run", "fastapi", "run", "nilai/main.py"]
# CMD ["--host", "0.0.0.0", "--port", "12345"]

CMD ["uv", "run", "fastapi", "run", "nilai/main.py", "--host", "0.0.0.0", "--port", "12345"]
CMD ["uv", "run", "gunicorn", "-c", "gunicorn.conf.py", "nilai.__main__:app"]
3 changes: 2 additions & 1 deletion docker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ docker build -t nillion/nilai:latest -f docker/Dockerfile .


docker run \
-p 12345:12345 \
-p 8080:8080 \
-p 8443:8443 \
-v hugging_face_models:/root/.cache/huggingface \
-v $(pwd)/users.sqlite:/app/users.sqlite \
nillion/nilai:latest
Expand Down
16 changes: 16 additions & 0 deletions gunicorn.conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# gunicorn.config.py

# Bind to address and port
bind = ["0.0.0.0:8080", "0.0.0.0:8443"]

# Set the number of workers (2)
workers = 2

# Set the number of threads per worker (16)
threads = 16

# Set the timeout (120 seconds)
timeout = 120

# Set the worker class to UvicornWorker for async handling
worker_class = "uvicorn.workers.UvicornWorker"
21 changes: 21 additions & 0 deletions nilai/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import uvicorn

from nilai.app import app


def run_uvicorn():
"""
Function to run the app with Uvicorn for debugging.
"""
uvicorn.run(
app,
host="0.0.0.0", # Listen on all interfaces
port=8080, # Use the desired port
reload=True, # Enable auto-reload for development
# ssl_certfile=SSL_CERTFILE,
# ssl_keyfile=SSL_KEYFILE,
)


if __name__ == "__main__":
run_uvicorn()
19 changes: 4 additions & 15 deletions nilai/main.py → nilai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,13 @@
"name": "Model",
"description": "Model information",
},
{
"name": "Usage",
"description": "User token usage",
},
],
)


app.include_router(public.router)
app.include_router(private.router, dependencies=[Depends(get_user)])

if __name__ == "__main__":
import uvicorn

# Path to your SSL certificate and key files
# SSL_CERTFILE = "/path/to/certificate.pem" # Replace with your certificate file path
# SSL_KEYFILE = "/path/to/private-key.pem" # Replace with your private key file path

uvicorn.run(
app,
host="0.0.0.0", # Listen on all interfaces
port=12345, # Use port 8443 for HTTPS
# ssl_certfile=SSL_CERTFILE,
# ssl_keyfile=SSL_KEYFILE,
)
56 changes: 42 additions & 14 deletions nilai/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class User(Base):
userid = Column(String(36), primary_key=True, index=True)
name = Column(String(100), nullable=False)
apikey = Column(String(36), unique=True, nullable=False, index=True)
input_tokens = Column(Integer, default=0, nullable=False)
generated_tokens = Column(Integer, default=0, nullable=False)
prompt_tokens = Column(Integer, default=0, nullable=False)
completion_tokens = Column(Integer, default=0, nullable=False)

def __repr__(self):
return f"<User(userid={self.userid}, name={self.name})>"
Expand Down Expand Up @@ -146,7 +146,7 @@ def insert_user(name: str) -> Dict[str, str]:
raise

@staticmethod
def check_api_key(api_key: str) -> Optional[str]:
def check_api_key(api_key: str) -> Optional[dict]:
"""
Validate an API key.

Expand All @@ -159,33 +159,59 @@ def check_api_key(api_key: str) -> Optional[str]:
try:
with get_db_session() as session:
user = session.query(User).filter(User.apikey == api_key).first()
return user.name if user else None # type: ignore
return {"name": user.name, "userid": user.userid} if user else None # type: ignore
except SQLAlchemyError as e:
logger.error(f"Error checking API key: {e}")
return None

@staticmethod
def update_token_usage(userid: str, input_tokens: int, generated_tokens: int):
def update_token_usage(userid: str, prompt_tokens: int, completion_tokens: int):
"""
Update token usage for a specific user.

Args:
userid (str): User's unique ID
input_tokens (int): Number of input tokens
generated_tokens (int): Number of generated tokens
prompt_tokens (int): Number of input tokens
completion_tokens (int): Number of generated tokens
"""
try:
with get_db_session() as session:
user = session.query(User).filter(User.userid == userid).first()
if user:
user.input_tokens += input_tokens # type: ignore
user.generated_tokens += generated_tokens # type: ignore
user.prompt_tokens += prompt_tokens # type: ignore
user.completion_tokens += completion_tokens # type: ignore
logger.info(f"Updated token usage for user {userid}")
else:
logger.warning(f"User {userid} not found")
except SQLAlchemyError as e:
logger.error(f"Error updating token usage: {e}")

@staticmethod
def get_token_usage(
userid: str,
) -> (
Dict[str, Any] | None
): # -> dict[str, Any] | None:# -> dict[str, Any] | None:# -> dict[str, Any] | None:# -> dict[str, Any] | None:# -> dict[str, Any] | None:
"""
Get token usage for a specific user.

Args:
userid (str): User's unique ID
"""
try:
with get_db_session() as session:
user = session.query(User).filter(User.userid == userid).first()
if user:
return {
"prompt_tokens": user.prompt_tokens,
"completion_tokens": user.completion_tokens,
"total_tokens": user.prompt_tokens + user.completion_tokens,
}
else:
logger.warning(f"User {userid} not found")
except SQLAlchemyError as e:
logger.error(f"Error updating token usage: {e}")

@staticmethod
def get_all_users() -> Optional[List[UserData]]:
"""
Expand All @@ -202,8 +228,8 @@ def get_all_users() -> Optional[List[UserData]]:
userid=user.userid, # type: ignore
name=user.name, # type: ignore
apikey=user.apikey, # type: ignore
input_tokens=user.input_tokens, # type: ignore
generated_tokens=user.generated_tokens, # type: ignore
input_tokens=user.prompt_tokens, # type: ignore
generated_tokens=user.completion_tokens, # type: ignore
)
for user in users
]
Expand All @@ -227,8 +253,8 @@ def get_user_token_usage(userid: str) -> Optional[Dict[str, int]]:
user = session.query(User).filter(User.userid == userid).first()
if user:
return {
"input_tokens": user.input_tokens,
"generated_tokens": user.generated_tokens,
"prompt_tokens": user.prompt_tokens,
"completion_tokens": user.completion_tokens,
} # type: ignore
return None
except SQLAlchemyError as e:
Expand All @@ -255,6 +281,8 @@ def get_user_token_usage(userid: str) -> Optional[Dict[str, int]]:
print(f"API key validation: {user_name}")

# Update and retrieve token usage
UserManager.update_token_usage(bob["userid"], input_tokens=50, generated_tokens=20)
UserManager.update_token_usage(
bob["userid"], prompt_tokens=50, completion_tokens=20
)
usage = UserManager.get_user_token_usage(bob["userid"])
print(f"Bob's token usage: {usage}")
Loading