Skip to content

Commit

Permalink
Only redirect to next url relative to current domain
Browse files Browse the repository at this point in the history
  • Loading branch information
debanjum committed Jun 18, 2024
1 parent 86a3505 commit 4daf16e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
12 changes: 6 additions & 6 deletions src/khoj/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
get_or_create_user,
)
from khoj.routers.email import send_welcome_email
from khoj.routers.helpers import update_telemetry_state
from khoj.routers.helpers import get_next_url, update_telemetry_state
from khoj.utils import state

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -94,7 +94,7 @@ async def delete_token(request: Request, token: str):
@auth_router.post("/redirect")
async def auth(request: Request):
form = await request.form()
next_url = request.query_params.get("next", "/")
next_url = get_next_url(request)
for q in request.query_params:
if not q == "next":
next_url += f"&{q}={request.query_params[q]}"
Expand All @@ -104,11 +104,11 @@ async def auth(request: Request):
csrf_token_cookie = request.cookies.get("g_csrf_token")
if not csrf_token_cookie:
logger.info("Missing CSRF token. Redirecting user to login page")
return RedirectResponse(url=f"{next_url}")
return RedirectResponse(url=next_url)
csrf_token_body = form.get("g_csrf_token")
if not csrf_token_body:
logger.info("Missing CSRF token body. Redirecting user to login page")
return RedirectResponse(url=f"{next_url}")
return RedirectResponse(url=next_url)
if csrf_token_cookie != csrf_token_body:
return Response("Invalid CSRF token", status_code=400)

Expand All @@ -130,9 +130,9 @@ async def auth(request: Request):
metadata={"user_id": str(khoj_user.uuid)},
)
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
return RedirectResponse(url=f"{next_url}", status_code=HTTP_302_FOUND)
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)

return RedirectResponse(url=f"{next_url}", status_code=HTTP_302_FOUND)
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)


@auth_router.get("/logout")
Expand Down
14 changes: 13 additions & 1 deletion src/khoj/routers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Tuple,
Union,
)
from urllib.parse import parse_qs, urlencode
from urllib.parse import parse_qs, urlencode, urljoin, urlparse

import cron_descriptor
import openai
Expand Down Expand Up @@ -161,6 +161,18 @@ def update_telemetry_state(
]


def get_next_url(request: Request) -> str:
"Construct next url relative to current domain from request"
next_url_param = urlparse(request.query_params.get("next", "/"))
next_path = "/" # default next path
# If relative path or absolute path to current domain
if is_none_or_empty(next_url_param.scheme) or next_url_param.netloc == request.base_url.netloc:
# Use path in next query param
next_path = next_url_param.path
# Construct absolute url using current domain and next path from request
return urljoin(str(request.base_url).rstrip("/"), next_path)


def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
chat_history = ""
for chat in conversation_history.get("chat", [])[-n:]:
Expand Down
3 changes: 2 additions & 1 deletion src/khoj/routers/web_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
get_user_subscription_state,
)
from khoj.database.models import KhojUser
from khoj.routers.helpers import get_next_url
from khoj.routers.notion import get_notion_auth_url
from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state
Expand Down Expand Up @@ -118,7 +119,7 @@ def chat_page(request: Request):

@web_client.get("/login", response_class=FileResponse)
def login_page(request: Request):
next_url = request.query_params.get("next", "/")
next_url = get_next_url(request)
if request.user.is_authenticated:
return RedirectResponse(url=next_url)
google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
Expand Down

0 comments on commit 4daf16e

Please sign in to comment.