Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions backend/app/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,20 @@
```

EXTREMELY Important notes on syntax!!! (PAY ATTENTION TO THIS):
- Make sure to add colour to the diagram!!! This is extremely critical.
- Make sure to add colour to the diagram!!! This is extremely critical. Not Following these rules will result in a syntax error!
- In Mermaid.js syntax, we cannot include special characters for nodes without being inside quotes! For example: `EX[/api/process (Backend)]:::api` and `API -->|calls Process()| Backend` are two examples of syntax errors. They should be `EX["/api/process (Backend)"]:::api` and `API -->|"calls Process()"| Backend` respectively. Notice the quotes. This is extremely important. Make sure to include quotes for any string that contains special characters.
- In Mermaid.js syntax, you cannot apply a class style directly within a subgraph declaration. For example: `subgraph "Frontend Layer":::frontend` is a syntax error. However, you can apply them to nodes within the subgraph. For example: `Example["Example Node"]:::frontend` is valid, and `class Example1,Example2 frontend` is valid.
- In Mermaid.js syntax, there cannot be spaces in the relationship label names. For example: `A -->| "example relationship" | B` is a syntax error. It should be `A -->|"example relationship"| B`
- In Mermaid.js syntax, connections should be following the format `A -->|"relationship"| B` without spaces around the relationship label. For example: `A -->|"relationship"| B` is valid, and `A -->| "relationship" | B` is a syntax error.
- In Mermaid.js syntax, there cannot be spaces in the relationship label names. For example: `A -->| "example relationship" | B` is a syntax error. It should be `A -->|"example relationship"| B`.
- In Mermaid.js syntax, you cannot give subgraphs an alias like nodes. For example: `subgraph A "Layer A"` is a syntax error. It should be `subgraph "Layer A"`
- In Mermaid.js syntax, you cannot use "direction TD", replace "direction TD" with "direction TB" everwhere neeeded. Very critical information , remember it.
-- Example `subgraph "Layer A" direction TD` is a syntax error. It should be `subgraph "Layer A" direction TB`
- In Mermaid.js syntax, you cannot use special characters in node names and no examples inside the nodes.
-- Example `A[("Example Node", (<text>))] and A[("Example Node"), (<text>)]` is a syntax error. It should be `A["Example Node"]:::example`
- In Mermaid.js syntax, you cannot use special characters in comments and no examples inside the comments.
-- Example `%% This is an example comment with special characters: @#$%^&*()[]{};:'",.<>?` is a syntax error. It should be `%% This is an example comment with special characters: @#$%^&*()[]{};:'",.<>?`
- In Mermaid.js syntax, you cannot add comments after any code line, For example: AI_ModelProviders_Group -->|"Returns Diagram Code"| BE_App %% Simplified return path is a syntax error. It should be AI_ModelProviders_Group -->|"Returns Diagram Code"| BE_App without any comments after it. It is very important, remember it!
- In Mermaid.js syntax, if you encounter the keyword "end" in the code, make sure not to add any comments after it. For example: `end %% This is an example comment` is a syntax error. It should be `end` without any comments after it!
"""
# ^^^ note: ive generated a few diagrams now and claude still writes incorrect mermaid code sometimes. in the future, refer to those generated diagrams and add important instructions to the prompt above to avoid those mistakes. examples are best.

Expand Down
72 changes: 61 additions & 11 deletions backend/app/routers/generate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from app.services.gemini_service import GeminiService
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import StreamingResponse
from dotenv import load_dotenv
Expand Down Expand Up @@ -25,39 +26,88 @@

# Initialize services
# claude_service = ClaudeService()
# gemini_service = GeminiService()
o4_service = OpenAIo4Service()


# cache github data to avoid double API calls from cost and generate
@lru_cache(maxsize=100)
def get_cached_github_data(username: str, repo: str, github_pat: str | None = None):
def get_cached_github_data(username: str, repo: str, github_pat: str | None = None, branch: str = ""):
# Create a new service instance for each call with the appropriate PAT
current_github_service = GitHubService(pat=github_pat)

default_branch = current_github_service.get_default_branch(username, repo)
if not default_branch:
default_branch = "main" # fallback value
defaultBranch = current_github_service.get_default_branch(username, repo)
if not defaultBranch:
defaultBranch = "main" # fallback value

file_tree = current_github_service.get_github_file_paths_as_list(username, repo)
file_tree = current_github_service.get_github_file_paths_as_list(username, repo, branch)
readme = current_github_service.get_github_readme(username, repo)

return {"default_branch": default_branch, "file_tree": file_tree, "readme": readme}
return {"defaultBranch": defaultBranch, "file_tree": file_tree, "readme": readme}

@lru_cache(maxsize=100)
def get_github_repo_branches(username: str, repo: str, github_pat: str | None = None):
"""Get all branches of a GitHub repository.
"""
# Create a new service instance for each call with the appropriate PAT
current_github_service = GitHubService(pat=github_pat)

branches = current_github_service.get_github_repo_branches(username, repo)
if not branches:
raise HTTPException(status_code=404, detail="No branches found in repository")

# Get the default branch as well to return it
defaultBranch = get_cached_github_data(username, repo, github_pat)["defaultBranch"]
if not defaultBranch:
defaultBranch = "main"
return {
"branches": branches["branches"],
"defaultBranch": defaultBranch,
}

class ApiRequest(BaseModel):
username: str
repo: str
instructions: str = ""
api_key: str | None = None
github_pat: str | None = None
branch: str = ""
page: int = 1
pageSize: int = 100

@router.post("/branches")
async def get_repo_branches(request: Request, body: ApiRequest):
try:
# Validate input
if not body.username or not body.repo:
raise HTTPException(status_code=400, detail="Username and repo are required")

# Create a new service instance with the appropriate PAT
current_github_service = GitHubService(pat=body.github_pat)

# Get branches with pagination
branches_data = current_github_service.get_github_repo_branches(
body.username, body.repo, body.page, body.pageSize
)

# Also get the default branch for compatibility
defaultBranch = current_github_service.get_default_branch(body.username, body.repo)

# Add defaultBranch to the response for compatibility
branches_data["defaultBranch"] = defaultBranch

return branches_data

except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@router.post("/cost")
# @limiter.limit("5/minute") # TEMP: disable rate limit for growth??
async def get_generation_cost(request: Request, body: ApiRequest):
try:
# Get file tree and README content
github_data = get_cached_github_data(body.username, body.repo, body.github_pat)
github_data = get_cached_github_data(body.username, body.repo, body.github_pat, body.branch)
file_tree = github_data["file_tree"]
readme = github_data["readme"]

Expand Down Expand Up @@ -136,9 +186,9 @@ async def event_generator():
try:
# Get cached github data
github_data = get_cached_github_data(
body.username, body.repo, body.github_pat
body.username, body.repo, body.github_pat, body.branch
)
default_branch = github_data["default_branch"]
defaultBranch = github_data["defaultBranch"]
file_tree = github_data["file_tree"]
readme = github_data["readme"]

Expand Down Expand Up @@ -243,7 +293,7 @@ async def event_generator():
return

processed_diagram = process_click_events(
mermaid_code, body.username, body.repo, default_branch
mermaid_code, body.username, body.repo, defaultBranch
)

# Send final result
Expand Down
112 changes: 112 additions & 0 deletions backend/app/services/gemini_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from dotenv import load_dotenv
from app.utils.format_message import format_user_message
import os
import aiohttp
import json
from typing import AsyncGenerator

load_dotenv()

class GeminiService:
def __init__(self):
self.api_key = os.getenv("GEMINI_API_KEY")
self.base_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent"

def call_gemini_api(
self,
system_prompt: str,
data: dict,
api_key: str | None = None,
) -> str:
"""
Makes an API call to Gemini and returns the response.
Args:
system_prompt (str): The instruction/system prompt
data (dict): Dictionary of variables to format into the user message
api_key (str | None): Optional custom API key
Returns:
str: Gemini's response text
"""
user_message = format_user_message(data)
key = api_key or self.api_key
if not key:
raise ValueError("Gemini API key is missing. Please set GEMINI_API_KEY in your environment or provide api_key.")
headers = {
"Content-Type": "application/json",
}
params = {"key": str(key)}
payload = {
"contents": [
{"role": "user", "parts": [{"text": f"{system_prompt}\n{user_message}"}]}
]
}
try:
import requests
response = requests.post(self.base_url, headers=headers, params=params, json=payload)
response.raise_for_status()
result = response.json()
return result["candidates"][0]["content"]["parts"][0]["text"]
except Exception as e:
print(f"Error in Gemini API call: {str(e)}")
raise

async def call_gemini_api_stream(
self,
system_prompt: str,
data: dict,
api_key: str | None = None,
) -> AsyncGenerator[str, None]:
"""
Makes a streaming API call to Gemini and yields the responses.
Args:
system_prompt (str): The instruction/system prompt
data (dict): Dictionary of variables to format into the user message
api_key (str | None): Optional custom API key
Yields:
str: Chunks of Gemini's response text
"""
user_message = format_user_message(data)
key = api_key or self.api_key
if not key:
raise ValueError("Gemini API key is missing. Please set GEMINI_API_KEY in your environment or provide api_key.")
headers = {
"Content-Type": "application/json",
}
params = {"key": str(key)}
payload = {
"contents": [
{"role": "user", "parts": [{"text": f"{system_prompt}\n{user_message}"}]}
]
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(self.base_url, headers=headers, params=params, json=payload) as response:
if response.status != 200:
error_text = await response.text()
print(f"Error response: {error_text}")
raise ValueError(f"Gemini API returned status code {response.status}: {error_text}")
response_text = await response.text()
try:
data = json.loads(response_text)
text = data["candidates"][0]["content"]["parts"][0]["text"]
if text:
yield text
except Exception as e:
print(f"Error parsing Gemini response: {e}")
except aiohttp.ClientError as e:
print(f"Connection error: {str(e)}")
raise ValueError(f"Failed to connect to Gemini API: {str(e)}")
except Exception as e:
print(f"Unexpected error in streaming API call: {str(e)}")
raise

def count_tokens(self, prompt: str) -> int:
"""
Counts the number of tokens in a prompt.
Args:
prompt (str): The prompt to count tokens for
Returns:
int: Estimated number of input tokens
"""
# Gemini does not have a public tokenizer, so we approximate by whitespace splitting
return len(prompt.split())
67 changes: 64 additions & 3 deletions backend/app/services/github_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,64 @@ def _check_repository_exists(self, username, repo):
f"Failed to check repository: {response.status_code}, {response.json()}"
)

def get_github_repo_branches(self, username, repo, page=1, pageSize=100):
"""
Get branches of a GitHub repository with pagination.

Args:
username (str): The GitHub username or organization name
repo (str): The repository name
page (int): Page number to fetch (default: 1)
pageSize (int): Number of branches per page (max: 100, default: 100)

Returns:
dict: A dictionary containing branch names, pagination info, and default branch.
"""
self._check_repository_exists(username, repo)

# Ensure pageSize doesn't exceed GitHub's limit
pageSize = min(pageSize, 100)

api_url = f"https://api.github.com/repos/{username}/{repo}/branches"
params = {
"page": page,
"pageSize": pageSize
}

response = requests.get(api_url, headers=self._get_headers(), params=params)

if response.status_code == 200:
branches_data = response.json()
branches = [branch["name"] for branch in branches_data]

# Parse pagination info from headers
link_header = response.headers.get('Link', '')
has_next = 'rel="next"' in link_header

# Get total count if available (not always provided by GitHub)
total_count = None
if 'Link' in response.headers:
# Try to extract total from last page link if available
import re
last_match = re.search(r'page=(\d+)>; rel="last"', link_header)
if last_match:
last_page = int(last_match.group(1))
# Estimate total (this is approximate)
total_count = (last_page - 1) * pageSize + len(branches_data)

return {
"branches": branches,
"pagination": {
"current_page": page,
"has_next": has_next,
"total_count": total_count
}
}

raise Exception(
f"Failed to fetch branches: {response.status_code}, {response.json()}"
)

def get_default_branch(self, username, repo):
"""Get the default branch of the repository."""
api_url = f"https://api.github.com/repos/{username}/{repo}"
Expand All @@ -107,7 +165,7 @@ def get_default_branch(self, username, repo):
return response.json().get("default_branch")
return None

def get_github_file_paths_as_list(self, username, repo):
def get_github_file_paths_as_list(self, username, repo, branch):
"""
Fetches the file tree of an open-source GitHub repository,
excluding static files and generated code.
Expand Down Expand Up @@ -160,8 +218,11 @@ def should_include_file(path):

return not any(pattern in path.lower() for pattern in excluded_patterns)

# Try to get the default branch first
branch = self.get_default_branch(username, repo)
#if the branch is empty, try to get the default branch
if not branch:
branch = self.get_default_branch(username, repo)

# Finding the file tree for the specified branch
if branch:
api_url = f"https://api.github.com/repos/{
username}/{repo}/git/trees/{branch}?recursive=1"
Expand Down
8 changes: 5 additions & 3 deletions src/app/[username]/[repo]/page.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"use client";

import { useParams } from "next/navigation";
import { useParams, useSearchParams } from "next/navigation";
import MainCard from "~/components/main-card";
import Loading from "~/components/loading";
import MermaidChart from "~/components/mermaid-diagram";
Expand All @@ -13,6 +13,7 @@ import { useStarReminder } from "~/hooks/useStarReminder";
export default function Repo() {
const [zoomingEnabled, setZoomingEnabled] = useState(false);
const params = useParams<{ username: string; repo: string }>();
const branch = useSearchParams().get("branch") ?? "";

// Use the star reminder hook
useStarReminder();
Expand All @@ -31,8 +32,8 @@ export default function Repo() {
handleCloseApiKeyDialog,
handleOpenApiKeyDialog,
handleExportImage,
state,
} = useDiagram(params.username.toLowerCase(), params.repo.toLowerCase());
state
} = useDiagram(params.username.toLowerCase(), params.repo.toLowerCase(), branch);

return (
<div className="flex flex-col items-center p-4">
Expand All @@ -41,6 +42,7 @@ export default function Repo() {
isHome={false}
username={params.username.toLowerCase()}
repo={params.repo.toLowerCase()}
branch={branch}
showCustomization={!loading && !error}
onModify={handleModify}
onRegenerate={handleRegenerate}
Expand Down
Loading