Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from rag_solution.file_management.database import get_db
from rag_solution.services.user_service import UserService

logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=settings.log_level)
logger = logging.getLogger(__name__)

class AuthMiddleware(BaseHTTPMiddleware):
class AuthenticationMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
logger.info(f"AuthMiddleware: Processing request to {request.url.path}")
logger.debug(f"AuthMiddleware: Request headers: {request.headers}")
Expand All @@ -29,7 +29,8 @@ async def dispatch(self, request: Request, call_next):
'id': payload.get('sub'),
'email': payload.get('email'),
'name': payload.get('name'),
'uuid': payload.get('uuid') # Extract UUID from payload
'uuid': payload.get('uuid'), # Extract UUID from payload
'role': payload.get('role')
}
logger.info(f"AuthMiddleware: JWT token validated successfully. User: {request.state.user}")
except jwt.ExpiredSignatureError:
Expand Down
64 changes: 64 additions & 0 deletions backend/core/authorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import functools
import logging
import re
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from core.config import settings

logging.basicConfig(level=settings.log_level)
logger = logging.getLogger(__name__)

open_paths = ['/api/auth/login', '/api/auth/callback', '/api/health', '/api/auth/oidc-config', '/api/auth/token', '/api/auth/userinfo']

async def authorize_dependency(request: Request):
"""
Dependency to check if the user is authorized to access the resource.
Uses the RBAC mapping from settings.rbac_mapping to check if the user is authorized to access the resource.

Args:
request (Request): The request object.

Returns:
bool: True if the request is authorized, raises HTTPException otherwise.
"""
logger.info(f"AuthorizationMiddleware: Processing request to {request.url.path} by {request.state.user}")
# print(f"AuthorizationMiddleware: Processing {request.method} request to {request.url.path} by {request.state.user}")
if request.url.path in open_paths:
return True

rrole = request.state.user.get('role')
rpath = request.url.path
try:
if rrole:
for pattern, method in settings.rbac_mapping[rrole].items():
if re.match(pattern, rpath) and request.method in method:
return True
raise HTTPException(status_code=403, detail="Failed to authorize request. {rpath} / {rrole}")

except (KeyError, ValueError) as exc:
logger.warning(f"Failed to authorize request. {rpath} / {rrole}")
raise HTTPException(status_code=403, detail="Failed to authorize request. {rpath} / {rrole}") from exc

def authorize_decorator(role: str):
"""
Decorator to check if the user is authorized to access the resource.

Args:
role (str): The role required to access the resource.

Returns:
function: Goes to the original handler (function) if the request is authorized, raises HTTPException otherwise.
"""
def decorator(handler):
@functools.wraps(handler)
async def wrapper(*args, **kwargs):
request = kwargs['request']
# print(f"AuthorizationDecorator: Processing {request.method} request to {request.url.path} by {request.state.user}")
if request.url.path not in open_paths:
if not request.state.user or request.state.user.get('role') != role:
logger.warning(f"AuthorizationDecorator: Unauthorized request to {request.url.path}")
return JSONResponse(status_code=403, content={"detail": f"User is not authorized to access this resource (requires {role} role)"})
return await handler(*args, **kwargs)
return wrapper
return decorator

27 changes: 25 additions & 2 deletions backend/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Settings(BaseSettings):
react_app_api_url: str

# Logging Level
log_level: Optional[str] = None
log_level: Optional[str] = "INFO"

# File storage path
file_storage_path: str = tempfile.gettempdir()
Expand Down Expand Up @@ -106,11 +106,34 @@ class Settings(BaseSettings):
# JWT settings
jwt_secret_key: str = Field(..., env='JWT_SECRET_KEY')
jwt_algorithm: str = "HS256"

frontend_callback: str = "/callback"

# Role settings
# This is a sample RBAC mapping role / url_patterns / http_methods
rbac_mapping: dict = {
'admin': {
r'^/api/user-collections/(.+)$': ['GET'],
r'^/api/user-collections/(.+)/(.+)$': ['POST', 'DELETE'],
},
'user': {
r'^/api/user-collections/(.+)/(.+)$': ['POST', 'DELETE'],
r'^/api/user-collections/(.+)$': ['GET'],
r'^/api/user-collections/collection/(.+)$': ['GET'],
r'^/api/user-collections/collection/(.+)/users$': ['DELETE'],
r'^/api/collections/(.+)$': ['GET']
},
'guest': {
r'^/api/user-collections$': ['GET', 'POST', 'DELETE', 'PUT'],
r'^/api/collections$': ['GET', 'POST', 'DELETE', 'PUT'],
r'^/api/collection/(.+)$': ['GET', 'POST', 'DELETE', 'PUT']
}
}

class Config:
env_file = ".env"
env_file_encoding = "utf-8"


settings = Settings(
react_app_api_url="http://localhost:3000",
)
20 changes: 20 additions & 0 deletions backend/core/loggingcors_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import logging
from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from core.config import settings

logging.basicConfig(level=settings.log_level)
logger = logging.getLogger(__name__)

class LoggingCORSMiddleware(CORSMiddleware):
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
logger.debug(f"CORS Request: method={scope['method']}, path={scope['path']}")
logger.debug(f"CORS Request headers: {scope['headers']}")

response = await super().__call__(scope, receive, send)

if scope["type"] == "http":
logger.debug(f"CORS Response headers: {response.headers}")

return response
74 changes: 38 additions & 36 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,30 @@

from fastapi import FastAPI, Depends, HTTPException, Request, Header
from fastapi.openapi.utils import get_openapi
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.sessions import SessionMiddleware
from pydantic import BaseModel
from sqlalchemy import inspect, text
from auth.oidc import verify_jwt_token
import jwt

# Import core modules
from core.authentication_middleware import AuthenticationMiddleware
# from backend.core.authorization_decorator import AuthorizationMiddleware
from core.loggingcors_middleware import LoggingCORSMiddleware
from core.authorization import authorize_dependency
from core.config import settings
from auth.oidc import get_current_user, verify_jwt_token
from rag_solution.file_management.database import Base, engine, get_db

# Import all models
from rag_solution.file_management.database import Base, engine, get_db
from rag_solution.models.user import User
from rag_solution.models.collection import Collection
from rag_solution.models.file import File
from rag_solution.models.user_collection import UserCollection
from rag_solution.models.user_team import UserTeam
from rag_solution.models.team import Team


from core.auth_middleware import AuthMiddleware
# Import all routers
from rag_solution.file_management.database import Base, engine
from rag_solution.router.collection_router import router as collection_router
from rag_solution.router.file_router import router as file_router
Expand All @@ -33,24 +37,10 @@
from rag_solution.router.user_team_router import router as user_team_router
from rag_solution.router.health_router import router as health_router
from rag_solution.router.auth_router import router as auth_router
from auth.oidc import get_current_user, oauth

logging.basicConfig(level=settings.log_level)
logger = logging.getLogger(__name__)

class LoggingCORSMiddleware(CORSMiddleware):
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
logger.debug(f"CORS Request: method={scope['method']}, path={scope['path']}")
logger.debug(f"CORS Request headers: {scope['headers']}")

response = await super().__call__(scope, receive, send)

if scope["type"] == "http":
logger.debug(f"CORS Response headers: {response.headers}")

return response

@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting database initialization")
Expand Down Expand Up @@ -123,27 +113,39 @@ async def lifespan(app: FastAPI):
)

# Add Auth middleware
app.add_middleware(AuthMiddleware)

async def auth_dependency(authorization: str = Header(...)):
try:
scheme, token = authorization.split()
if scheme.lower() != 'bearer':
raise HTTPException(status_code=401, detail="Invalid authentication scheme")
payload = verify_jwt_token(token)
return payload
except (jwt.PyJWTError, ValueError):
raise HTTPException(status_code=401, detail="Invalid or expired token")
app.add_middleware(AuthenticationMiddleware)

# Replacing with a decorator on each endpoint to allow fine-grained authorization
# app.add_middleware(AuthorizationMiddleware)

# Already included in AuthMiddleware
# async def auth_dependency(authorization: str = Header(...)):
# try:
# scheme, token = authorization.split()
# if scheme.lower() != 'bearer':
# raise HTTPException(status_code=401, detail="Invalid authentication scheme")
# payload = verify_jwt_token(token)
# return payload
# except (jwt.PyJWTError, ValueError):
# raise HTTPException(status_code=401, detail="Invalid or expired token")

# Include routers
app.include_router(auth_router)
app.include_router(health_router)
app.include_router(collection_router, dependencies=[Depends(auth_dependency)])
app.include_router(file_router, dependencies=[Depends(auth_dependency)])
app.include_router(team_router, dependencies=[Depends(auth_dependency)])
app.include_router(user_router, dependencies=[Depends(auth_dependency)])
app.include_router(user_collection_router, dependencies=[Depends(auth_dependency)])
app.include_router(user_team_router, dependencies=[Depends(auth_dependency)])
app.include_router(collection_router)
app.include_router(file_router)
app.include_router(team_router)
app.include_router(user_router)
app.include_router(user_collection_router, dependencies=[Depends(authorize_dependency)])
# app.include_router(user_collection_router)
app.include_router(user_team_router)
# app.include_router(collection_router, dependencies=[Depends(auth_dependency)])
# app.include_router(file_router, dependencies=[Depends(auth_dependency)])
# app.include_router(team_router, dependencies=[Depends(auth_dependency)])
# app.include_router(user_router, dependencies=[Depends(auth_dependency)])
# app.include_router(user_collection_router, dependencies=[Depends(auth_dependency)])
# app.include_router(user_team_router, dependencies=[Depends(auth_dependency)])
# app.include_router(collection_router, dependencies=[Depends(auth_dependency)])

def custom_openapi():
if app.openapi_schema:
Expand Down
6 changes: 4 additions & 2 deletions backend/rag_solution/router/auth_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,14 @@ async def auth(request: Request, db: Session = Depends(get_db)):
"sub": user['sub'],
"email": user['email'],
"name": user.get('name', 'Unknown'),
"uuid": str(db_user.id) # Include the UUID in the JWT payload
"uuid": str(db_user.id) , # Include the UUID in the JWT payload
"exp": token.get('expires_at'),
"role": "admin"
}
custom_jwt = jwt.encode(custom_jwt_payload, settings.ibm_client_secret, algorithm="HS256")


redirect_url = f"{settings.frontend_url}/?token={custom_jwt}"
redirect_url = f"{settings.frontend_url}{settings.frontend_callback}/?token={custom_jwt}"
logger.info(f"Redirecting to frontend: {redirect_url}")

return RedirectResponse(url=redirect_url)
Expand Down
3 changes: 1 addition & 2 deletions backend/rag_solution/router/collection_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from rag_solution.schemas.user_schema import UserInput
from rag_solution.services.user_service import UserService
from rag_solution.services.collection_service import CollectionService
from rag_solution.services.file_management_service import \
FileManagementService
from rag_solution.services.file_management_service import FileManagementService
import logging

logging.basicConfig(level=logging.INFO)
Expand Down
5 changes: 4 additions & 1 deletion backend/rag_solution/router/user_collection_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from rag_solution.services.user_collection_interaction_service import UserCollectionInteractionService
from rag_solution.schemas.user_collection_schema import UserCollectionOutput, UserCollectionsOutput

from core.authorization import authorize_decorator

router = APIRouter(prefix="/api/user-collections", tags=["user-collections"])

@router.post("/{user_id}/{collection_id}",
Expand Down Expand Up @@ -48,7 +50,8 @@ def remove_user_from_collection(user_id: UUID, collection_id: UUID, db: Session
500: {"description": "Internal server error"}
}
)
def get_user_collections(user_id: UUID, request: Request, db: Session = Depends(get_db)):
@authorize_decorator(role="admin")
async def get_user_collections(user_id: UUID, request: Request, db: Session = Depends(get_db)):
if not hasattr(request.state, 'user') or request.state.user['uuid'] != str(user_id):
raise HTTPException(status_code=403, detail="Not authorized to access this resource")

Expand Down
6 changes: 6 additions & 0 deletions webui/jsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"compilerOptions": {
"baseUrl": "."
},
"include": ["src"]
}
3 changes: 2 additions & 1 deletion webui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
"start": "react-scripts start",
"build": "DISABLE_ESLINT_PLUGIN=true react-scripts build",
"test": "react-scripts test",
"eject": "react-scripts eject"
"eject": "react-scripts eject",
"dev": "WDS_SOCKET_PORT=3000 CHOKIDAR_USEPOLLING=true WATCHPACK_POLLING=true FAST_REFRESH=false react-scripts start"
},
"proxy": "http://localhost:8000",
"eslintConfig": {
Expand Down
17 changes: 9 additions & 8 deletions webui/src/App.css
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
min-height: 100vh;
}

.bx--content {
margin-top: 3rem; /* Adjust based on your Header height */
flex-grow: 1;
background-color: #f4f4f4;
}

.page-container {
max-width: 1200px;
margin: 0 auto;
display: flex;
flex-direction: column;
flex-grow: 1;
padding: 2rem;
}

/* .bx--content {
margin-top: 3rem; Adjust based on your Header height
flex-grow: 1;
background-color: #f4f4f4;
} */

@media (max-width: 1056px) {
.page-container {
padding: 1rem;
Expand Down
Loading
Loading