Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions src/gitingest/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from config import TMP_BASE_PATH
from gitingest.exceptions import InvalidPatternError
from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS
from gitingest.repository_clone import _check_repo_exists
from gitingest.repository_clone import _check_repo_exists, fetch_remote_branch_list

HEX_DIGITS: set[str] = set(string.hexdigits)

Expand Down Expand Up @@ -168,18 +168,47 @@ async def _parse_repo_source(source: str) -> dict[str, Any]:
parsed["type"] = possible_type

# Commit or branch
commit_or_branch = remaining_parts.pop(0)
commit_or_branch = remaining_parts[0]
if _is_valid_git_commit_hash(commit_or_branch):
parsed["commit"] = commit_or_branch
else:
parsed["branch"] = commit_or_branch
parsed["subpath"] += "/".join(remaining_parts[1:])

# Subpath if anything left
if remaining_parts:
else:
parsed["branch"] = await _configure_branch_and_subpath(remaining_parts, url)
parsed["subpath"] += "/".join(remaining_parts)

return parsed

async def _configure_branch_and_subpath(remaining_parts: list[str],url: str) -> str | None:
"""
Find the branch name from the remaining parts of the URL path.
Parameters
----------
remaining_parts : list[str]
List of path parts extracted from the URL.
url : str
The repository URL to determine branches.

Returns
-------
str (branch name) or None

"""
try:
# Fetch the list of branches from the remote repository
branches: list[str] = await fetch_remote_branch_list(url)
except Exception as e:
print(f"Warning: Failed to fetch branch list: {str(e)}")
return remaining_parts.pop(0) if remaining_parts else None

branch = []

while remaining_parts:
branch.append(remaining_parts.pop(0))
branch_name = "/".join(branch)
if branch_name in branches:
return branch_name

return None

def _is_valid_git_commit_hash(commit: str) -> bool:
"""
Expand Down
30 changes: 28 additions & 2 deletions src/gitingest/repository_clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from gitingest.utils import async_timeout

CLONE_TIMEOUT: int = 20
TIMEOUT: int = 20


@dataclass
Expand Down Expand Up @@ -34,7 +34,7 @@ class CloneConfig:
branch: str | None = None


@async_timeout(CLONE_TIMEOUT)
@async_timeout(TIMEOUT)
async def clone_repo(config: CloneConfig) -> tuple[bytes, bytes]:
"""
Clone a repository to a local path based on the provided configuration.
Expand Down Expand Up @@ -141,6 +141,32 @@ async def _check_repo_exists(url: str) -> bool:
raise RuntimeError(f"Unexpected status code: {status_code}")


@async_timeout(TIMEOUT)
async def fetch_remote_branch_list(url: str) -> list[str]:
"""
Get the list of branches from the remote repo.

Parameters
----------
url : str
The URL of the repository.

Returns
-------
list[str]
list of the branches in the remote repository
"""
fetch_branches_command = ["git", "ls-remote", "--heads", url]
stdout, stderr = await _run_git_command(*fetch_branches_command)
stdout_decoded = stdout.decode()

return [
line.split('refs/heads/', 1)[1]
for line in stdout_decoded.splitlines()
if line.strip() and 'refs/heads/' in line
]


async def _run_git_command(*args: str) -> tuple[bytes, bytes]:
"""
Execute a Git command asynchronously and captures its output.
Expand Down
72 changes: 59 additions & 13 deletions tests/query_parser/test_query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from pathlib import Path

import pytest
from unittest.mock import patch, AsyncMock
from gitingest.repository_clone import _check_repo_exists, fetch_remote_branch_list

from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS
from gitingest.query_parser import _parse_patterns, _parse_repo_source, parse_query
Expand Down Expand Up @@ -96,18 +98,21 @@ async def test_parse_query_invalid_pattern() -> None:
with pytest.raises(ValueError, match="Pattern.*contains invalid characters"):
await parse_query(url, max_file_size=50, from_web=True, include_patterns="*.py;rm -rf")


async def test_parse_url_with_subpaths() -> None:
"""
Test `_parse_repo_source` with a URL containing a branch and subpath.
Verifies that user name, repository name, branch, and subpath are correctly extracted.
"""
url = "https://github.com/user/repo/tree/main/subdir/file"
result = await _parse_repo_source(url)
assert result["user_name"] == "user"
assert result["repo_name"] == "repo"
assert result["branch"] == "main"
assert result["subpath"] == "/subdir/file"
with patch('gitingest.repository_clone._run_git_command', new_callable=AsyncMock) as mock_run_git_command:
mock_run_git_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"")
with patch('gitingest.repository_clone.fetch_remote_branch_list', new_callable=AsyncMock) as mock_fetch_branches:
mock_fetch_branches.return_value = ["main", "dev", "feature-branch"]
result = await _parse_repo_source(url)
assert result["user_name"] == "user"
assert result["repo_name"] == "repo"
assert result["branch"] == "main"
assert result["subpath"] == "/subdir/file"


async def test_parse_url_invalid_repo_structure() -> None:
Expand Down Expand Up @@ -222,15 +227,18 @@ async def test_parse_url_branch_and_commit_distinction() -> None:
url_branch = "https://github.com/user/repo/tree/main"
url_commit = "https://github.com/user/repo/tree/abcd1234abcd1234abcd1234abcd1234abcd1234"

result_branch = await _parse_repo_source(url_branch)
result_commit = await _parse_repo_source(url_commit)

assert result_branch["branch"] == "main"
assert result_branch["commit"] is None
with patch('gitingest.repository_clone._run_git_command', new_callable=AsyncMock) as mock_run_git_command:
mock_run_git_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"")
with patch('gitingest.repository_clone.fetch_remote_branch_list', new_callable=AsyncMock) as mock_fetch_branches:
mock_fetch_branches.return_value = ["main", "dev", "feature-branch"]

assert result_commit["branch"] is None
assert result_commit["commit"] == "abcd1234abcd1234abcd1234abcd1234abcd1234"
result_branch = await _parse_repo_source(url_branch)
result_commit = await _parse_repo_source(url_commit)
assert result_branch["branch"] == "main"
assert result_branch["commit"] is None

assert result_commit["branch"] is None
assert result_commit["commit"] == "abcd1234abcd1234abcd1234abcd1234abcd1234"

async def test_parse_query_uuid_uniqueness() -> None:
"""
Expand Down Expand Up @@ -274,3 +282,41 @@ async def test_parse_query_with_branch() -> None:
assert result["branch"] == "2.2.x"
assert result["commit"] is None
assert result["type"] == "blob"

@pytest.mark.asyncio
@pytest.mark.parametrize("url, expected_branch, expected_subpath", [
("https://github.com/user/repo/tree/main/src", "main", "/src"),
("https://github.com/user/repo/tree/fix1", "fix1", "/"),
("https://github.com/user/repo/tree/nonexistent-branch/src", "nonexistent-branch", "/src"),
])
async def test_parse_repo_source_with_failed_git_command(url, expected_branch, expected_subpath):
"""
Test `_parse_repo_source` when git command fails.
Verifies that the function returns the first path component as the branch.
"""
with patch('gitingest.repository_clone.fetch_remote_branch_list', new_callable=AsyncMock) as mock_fetch_branches:
mock_fetch_branches.side_effect = Exception("Failed to fetch branch list")

result = await _parse_repo_source(url)

assert result["branch"] == expected_branch
assert result["subpath"] == expected_subpath

@pytest.mark.asyncio
@pytest.mark.parametrize("url, expected_branch, expected_subpath", [
("https://github.com/user/repo/tree/feature/fix1/src", "feature/fix1", "/src"),
("https://github.com/user/repo/tree/main/src", "main", "/src"),
("https://github.com/user/repo", None, "/"), # No
("https://github.com/user/repo/tree/nonexistent-branch/src", None, "/"), # Non-existent branch
("https://github.com/user/repo/tree/fix", "fix", "/"),
])
async def test_parse_repo_source_with_various_url_patterns(url, expected_branch, expected_subpath):
with patch('gitingest.repository_clone._run_git_command', new_callable=AsyncMock) as mock_run_git_command, \
patch('gitingest.repository_clone.fetch_remote_branch_list', new_callable=AsyncMock) as mock_fetch_branches:

mock_run_git_command.return_value = (b"refs/heads/feature/fix1\nrefs/heads/main\nrefs/heads/feature-branch\nrefs/heads/fix\n", b"")
mock_fetch_branches.return_value = ["feature/fix1", "main", "feature-branch"]

result = await _parse_repo_source(url)
assert result["branch"] == expected_branch
assert result["subpath"] == expected_subpath