Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions src/gitingest/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import re
import string
import uuid
import warnings
from pathlib import Path
from typing import Any
from urllib.parse import unquote, urlparse

from config import TMP_BASE_PATH
from gitingest.exceptions import InvalidPatternError
from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS
from gitingest.repository_clone import _check_repo_exists
from gitingest.repository_clone import _check_repo_exists, fetch_remote_branch_list

HEX_DIGITS: set[str] = set(string.hexdigits)

Expand Down Expand Up @@ -169,19 +170,48 @@ async def _parse_repo_source(source: str) -> dict[str, Any]:
parsed["type"] = possible_type

# Commit or branch
commit_or_branch = remaining_parts.pop(0)
commit_or_branch = remaining_parts[0]
if _is_valid_git_commit_hash(commit_or_branch):
parsed["commit"] = commit_or_branch
parsed["subpath"] += "/".join(remaining_parts[1:])
else:
parsed["branch"] = commit_or_branch

# Subpath if anything left
if remaining_parts:
parsed["branch"] = await _configure_branch_and_subpath(remaining_parts, url)
parsed["subpath"] += "/".join(remaining_parts)

return parsed


async def _configure_branch_and_subpath(remaining_parts: list[str], url: str) -> str | None:
    """
    Resolve the branch name from the leading components of a repository URL path.

    Branch names may themselves contain slashes, so path components are joined one at a
    time and each candidate is compared against the branches reported by the remote.

    Parameters
    ----------
    remaining_parts : list[str]
        The remaining parts of the URL path. This list is consumed in place: every component
        examined while searching for the branch is removed from it, so whatever is left after
        the call is what the caller treats as the subpath. If no candidate matches, the list
        ends up empty.
    url : str
        The URL of the repository.

    Returns
    -------
    str | None
        The branch name if found, otherwise None. If the branch list cannot be fetched, the
        first path component is returned as a best-effort guess (or None when there is none).
    """
    try:
        # A set gives constant-time membership checks for every candidate tested below.
        branches = set(await fetch_remote_branch_list(url))
    except RuntimeError as exc:
        warnings.warn(f"Warning: Failed to fetch branch list: {exc}", stacklevel=2)
        # Best effort: assume the branch is the first component; tolerate an empty path.
        return remaining_parts.pop(0) if remaining_parts else None

    candidate_parts: list[str] = []
    while remaining_parts:
        # Grow the candidate by one component at a time: "a", then "a/b", then "a/b/c", ...
        candidate_parts.append(remaining_parts.pop(0))
        branch_name = "/".join(candidate_parts)
        if branch_name in branches:
            return branch_name

    return None


def _is_valid_git_commit_hash(commit: str) -> bool:
"""
Validate if the provided string is a valid Git commit hash.
Expand Down
28 changes: 26 additions & 2 deletions src/gitingest/repository_clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from gitingest.utils import async_timeout

CLONE_TIMEOUT: int = 20
TIMEOUT: int = 20


@dataclass
Expand Down Expand Up @@ -34,7 +34,7 @@ class CloneConfig:
branch: str | None = None


@async_timeout(CLONE_TIMEOUT)
@async_timeout(TIMEOUT)
async def clone_repo(config: CloneConfig) -> tuple[bytes, bytes]:
"""
Clone a repository to a local path based on the provided configuration.
Expand Down Expand Up @@ -141,6 +141,30 @@ async def _check_repo_exists(url: str) -> bool:
raise RuntimeError(f"Unexpected status code: {status_code}")


@async_timeout(TIMEOUT)
async def fetch_remote_branch_list(url: str) -> list[str]:
    """
    List the branch names advertised by a remote Git repository.

    Runs ``git ls-remote --heads <url>`` and keeps, for every output line that contains
    ``refs/heads/``, the text following the first occurrence of that marker.

    Parameters
    ----------
    url : str
        The URL of the Git repository to query.

    Returns
    -------
    list[str]
        The branch names, in the order they appear in the command output.
    """
    ref_marker = "refs/heads/"
    raw_output, _ = await _run_git_command("git", "ls-remote", "--heads", url)

    branch_names: list[str] = []
    for ref_line in raw_output.decode().splitlines():
        # partition() yields an empty separator when the marker is absent from the line.
        _, found_marker, name = ref_line.partition(ref_marker)
        if found_marker:
            branch_names.append(name)

    return branch_names


async def _run_git_command(*args: str) -> tuple[bytes, bytes]:
"""
Execute a Git command asynchronously and captures its output.
Expand Down
87 changes: 76 additions & 11 deletions tests/query_parser/test_query_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" Tests for the query_parser module. """

from pathlib import Path
from unittest.mock import AsyncMock, patch

import pytest

Expand Down Expand Up @@ -109,11 +110,17 @@ async def test_parse_url_with_subpaths() -> None:
Verifies that user name, repository name, branch, and subpath are correctly extracted.
"""
url = "https://github.com/user/repo/tree/main/subdir/file"
result = await _parse_repo_source(url)
assert result["user_name"] == "user"
assert result["repo_name"] == "repo"
assert result["branch"] == "main"
assert result["subpath"] == "/subdir/file"
with patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_run_git_command:
mock_run_git_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"")
with patch(
"gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock
) as mock_fetch_branches:
mock_fetch_branches.return_value = ["main", "dev", "feature-branch"]
result = await _parse_repo_source(url)
assert result["user_name"] == "user"
assert result["repo_name"] == "repo"
assert result["branch"] == "main"
assert result["subpath"] == "/subdir/file"


async def test_parse_url_invalid_repo_structure() -> None:
Expand Down Expand Up @@ -228,14 +235,20 @@ async def test_parse_url_branch_and_commit_distinction() -> None:
url_branch = "https://github.com/user/repo/tree/main"
url_commit = "https://github.com/user/repo/tree/abcd1234abcd1234abcd1234abcd1234abcd1234"

result_branch = await _parse_repo_source(url_branch)
result_commit = await _parse_repo_source(url_commit)
with patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_run_git_command:
mock_run_git_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"")
with patch(
"gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock
) as mock_fetch_branches:
mock_fetch_branches.return_value = ["main", "dev", "feature-branch"]

assert result_branch["branch"] == "main"
assert result_branch["commit"] is None
result_branch = await _parse_repo_source(url_branch)
result_commit = await _parse_repo_source(url_commit)
assert result_branch["branch"] == "main"
assert result_branch["commit"] is None

assert result_commit["branch"] is None
assert result_commit["commit"] == "abcd1234abcd1234abcd1234abcd1234abcd1234"
assert result_commit["branch"] is None
assert result_commit["commit"] == "abcd1234abcd1234abcd1234abcd1234abcd1234"


async def test_parse_query_uuid_uniqueness() -> None:
Expand Down Expand Up @@ -280,3 +293,55 @@ async def test_parse_query_with_branch() -> None:
assert result["branch"] == "2.2.x"
assert result["commit"] is None
assert result["type"] == "blob"


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "url, expected_branch, expected_subpath",
    [
        ("https://github.com/user/repo/tree/main/src", "main", "/src"),
        ("https://github.com/user/repo/tree/fix1", "fix1", "/"),
        ("https://github.com/user/repo/tree/nonexistent-branch/src", "nonexistent-branch", "/src"),
    ],
)
async def test_parse_repo_source_with_failed_git_command(url, expected_branch, expected_subpath):
    """
    Test `_parse_repo_source` when the remote branch list cannot be fetched.

    Verifies that the function warns and falls back to the first path component as the branch.
    """
    # Patch the name where it is looked up: `query_parser` binds `fetch_remote_branch_list` with a
    # `from ... import`, so patching it on `repository_clone` would leave the parser calling the real
    # function (and the real `git ls-remote`). The side effect is a RuntimeError because that is the
    # exception type the parser's fallback handles.
    with patch("gitingest.query_parser.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches:
        mock_fetch_branches.side_effect = RuntimeError("Failed to fetch branch list")

        with pytest.warns(UserWarning, match="Failed to fetch branch list"):
            result = await _parse_repo_source(url)

        assert result["branch"] == expected_branch
        assert result["subpath"] == expected_subpath


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "url, expected_branch, expected_subpath",
    [
        ("https://github.com/user/repo/tree/feature/fix1/src", "feature/fix1", "/src"),
        ("https://github.com/user/repo/tree/main/src", "main", "/src"),
        ("https://github.com/user/repo", None, "/"),  # No branch in the URL
        ("https://github.com/user/repo/tree/nonexistent-branch/src", None, "/"),  # Non-existent branch
        ("https://github.com/user/repo/tree/fix", "fix", "/"),
        ("https://github.com/user/repo/blob/fix/page.html", "fix", "/page.html"),
    ],
)
async def test_parse_repo_source_with_various_url_patterns(url, expected_branch, expected_subpath):
    """
    Test `_parse_repo_source` against a known set of remote branches.

    Covers branch names containing slashes, URLs without a branch, branches that do not exist on
    the remote, and both `tree` and `blob` URL forms.
    """
    with (
        # Kept as a guard so no real git command can run during the test.
        patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_run_git_command,
        # Patched on `query_parser`, where the imported name is actually looked up.
        patch("gitingest.query_parser.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches,
    ):
        mock_run_git_command.return_value = (
            b"refs/heads/feature/fix1\nrefs/heads/main\nrefs/heads/feature-branch\nrefs/heads/fix\n",
            b"",
        )
        # Must list every branch the parametrized cases expect to resolve, including "fix".
        mock_fetch_branches.return_value = ["feature/fix1", "main", "feature-branch", "fix"]

        result = await _parse_repo_source(url)

        assert result["branch"] == expected_branch
        assert result["subpath"] == expected_subpath
Loading