Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lfs: add support for Git SSH URLs #325

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 99 additions & 28 deletions src/scmrepo/git/lfs/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import json
import logging
import os
import re
import shutil
from abc import abstractmethod
from collections.abc import Iterable, Iterator
from contextlib import AbstractContextManager, contextmanager, suppress
from tempfile import NamedTemporaryFile
Expand All @@ -13,6 +16,7 @@
from fsspec.implementations.http import HTTPFileSystem
from funcy import cached_property

from scmrepo.git.backend.dulwich import _get_ssh_vendor
from scmrepo.git.credentials import Credential, CredentialNotFoundError

from .exceptions import LFSError
Expand All @@ -35,19 +39,12 @@ class LFSClient(AbstractContextManager):
_SESSION_RETRIES = 5
_SESSION_BACKOFF_FACTOR = 0.1

def __init__(
self,
url: str,
git_url: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
):
def __init__(self, url: str):
"""
Args:
url: LFS server URL.
"""
self.url = url
self.git_url = git_url
self.headers: dict[str, str] = headers or {}

def __exit__(self, *args, **kwargs):
self.close()
Expand Down Expand Up @@ -84,23 +81,18 @@ def loop(self):

@classmethod
def from_git_url(cls, git_url: str) -> "LFSClient":
if git_url.endswith(".git"):
url = f"{git_url}/info/lfs"
else:
url = f"{git_url}.git/info/lfs"
return cls(url, git_url=git_url)
if git_url.startswith(("ssh://", "git@")):
return _SSHLFSClient.from_git_url(git_url)
if git_url.startswith("https://"):
return _HTTPLFSClient.from_git_url(git_url)
raise NotImplementedError(f"Unsupported Git URL: {git_url}")

def close(self):
pass

def _get_auth(self) -> Optional[aiohttp.BasicAuth]:
try:
creds = Credential(url=self.git_url).fill()
if creds.username and creds.password:
return aiohttp.BasicAuth(creds.username, creds.password)
except CredentialNotFoundError:
pass
return None
@abstractmethod
def _get_auth_header(self, *, upload: bool) -> dict:
...

async def _batch_request(
self,
Expand All @@ -120,9 +112,10 @@ async def _batch_request(
if ref:
body["ref"] = [{"name": ref}]
session = await self._fs.set_session()
headers = dict(self.headers)
headers["Accept"] = self.JSON_CONTENT_TYPE
headers["Content-Type"] = self.JSON_CONTENT_TYPE
headers = {
"Accept": self.JSON_CONTENT_TYPE,
"Content-Type": self.JSON_CONTENT_TYPE,
}
try:
async with session.post(
url,
Expand All @@ -134,13 +127,12 @@ async def _batch_request(
except aiohttp.ClientResponseError as exc:
if exc.status != 401:
raise
auth = self._get_auth()
if auth is None:
auth_header = self._get_auth_header(upload=upload)
if not auth_header:
raise
async with session.post(
url,
auth=auth,
headers=headers,
headers={**headers, **auth_header},
json=body,
raise_for_status=True,
) as resp:
Expand Down Expand Up @@ -186,6 +178,85 @@ async def _get_one(from_path: str, to_path: str, **kwargs):
download = sync_wrapper(_download)


class _HTTPLFSClient(LFSClient):
    """LFS client for a repository addressed by an HTTP(S) Git URL.

    Credentials for the retry-with-auth path are looked up against the Git
    URL (not the derived LFS URL) and sent as an HTTP Basic ``Authorization``
    header.
    """

    def __init__(self, url: str, git_url: str):
        """
        Args:
            url: LFS server URL.
            git_url: Git HTTP URL.
        """
        super().__init__(url)
        # Kept as given by the caller: it is the key for the credential lookup
        # in `_get_auth_header`.
        self.git_url = git_url

    @classmethod
    def from_git_url(cls, git_url: str) -> "_HTTPLFSClient":
        """Build a client whose LFS endpoint is derived from ``git_url``.

        The endpoint is ``<git_url>.git/info/lfs``; the ``.git`` suffix is
        only appended when ``git_url`` does not already end with it.

        Args:
            git_url: Git HTTP URL, with or without a trailing ``.git`` and
                with or without trailing slashes.

        Returns:
            A client bound to the derived LFS URL and the original Git URL.
        """
        # Strip trailing slashes first, otherwise "https://h/r.git/" would
        # fail the suffix check and yield "https://h/r.git/.git/info/lfs".
        base = git_url.rstrip("/")
        suffix = "" if base.endswith(".git") else ".git"
        return cls(f"{base}{suffix}/info/lfs", git_url=git_url)

    def _get_auth_header(self, *, upload: bool) -> dict:
        """Return a Basic ``Authorization`` header for the Git URL, if any.

        Args:
            upload: Unused here; the same credentials are returned for
                uploads and downloads.

        Returns:
            A single-entry header mapping when a credential with both a
            username and a password is found, otherwise an empty dict.
        """
        try:
            creds = Credential(url=self.git_url).fill()
        except CredentialNotFoundError:
            # No stored credential is not an error: the caller treats an
            # empty mapping as "no authentication available".
            return {}
        if not (creds.username and creds.password):
            return {}
        basic = aiohttp.BasicAuth(creds.username, creds.password)
        return {aiohttp.hdrs.AUTHORIZATION: basic.encode()}


class _SSHLFSClient(LFSClient):
    """LFS client for a repository addressed by a Git SSH URL.

    Object transfer still goes over HTTPS (see `from_git_url`); SSH is only
    used to run ``git-lfs-authenticate`` on the Git host and obtain the
    headers to send with the HTTPS requests.
    """

    # Accepts "ssh://git@host[:port]/path.git" and "git@host:path.git".
    # Only the "git" user and URLs ending in ".git" are matched.
    _URL_PATTERN = re.compile(
        r"(?:ssh://)?git@(?P<host>\S+?)(?::(?P<port>\d+))?(?:[:/])(?P<path>\S+?)\.git"
    )

    def __init__(self, url: str, host: str, port: int, path: str):
        """
        Args:
            url: LFS server URL.
            host: Git SSH server host.
            port: Git SSH server port.
            path: Git project path.
        """
        super().__init__(url)
        self.host = host
        self.port = port
        self.path = path
        # NOTE(review): presumably the same SSH vendor the dulwich backend
        # uses for Git transport -- confirm against `_get_ssh_vendor`.
        self._ssh = _get_ssh_vendor()

    @classmethod
    def from_git_url(cls, git_url: str) -> "_SSHLFSClient":
        """Build a client from a Git SSH URL.

        The LFS endpoint is assumed to live at
        ``https://<host>/<path>.git/info/lfs``; the SSH port (default 22) is
        only used for the authentication command, not for the HTTPS URL.

        Args:
            git_url: Git SSH URL matching `_URL_PATTERN`.

        Returns:
            A client bound to the derived LFS URL and the parsed SSH target.

        Raises:
            ValueError: If ``git_url`` does not match `_URL_PATTERN`.
        """
        parsed = cls._URL_PATTERN.match(git_url)
        if parsed is None:
            raise ValueError(f"Invalid Git SSH URL: {git_url}")
        host, port, path = parsed.group("host", "port", "path")
        url = f"https://{host}/{path}.git/info/lfs"
        return cls(url, host, int(port or 22), path)

    def _get_auth_header(self, *, upload: bool) -> dict:
        """Return the headers handed out by ``git-lfs-authenticate``.

        Args:
            upload: Request upload rather than download authorization.

        Returns:
            The ``"header"`` entry of the server's JSON response, or an
            empty dict when the response has no such entry.
        """
        response = self._git_lfs_authenticate(
            self.host, self.port, f"{self.path}.git", upload=upload
        )
        return response.get("header", {})

    def _git_lfs_authenticate(
        self, host: str, port: int, path: str, *, upload: bool = False
    ) -> dict:
        """Run ``git-lfs-authenticate`` over SSH and parse its JSON output.

        Args:
            host: Git SSH server host.
            port: Git SSH server port.
            path: Repository path passed to the remote command.
            upload: Ask for the "upload" action instead of "download".

        Returns:
            The decoded JSON document printed by the remote command.

        Raises:
            json.JSONDecodeError: If the remote output is not valid JSON
                (for example, when it is empty).
        """
        # Local import keeps this block self-contained.
        import shlex

        action = "upload" if upload else "download"
        # The path comes from a user- or repo-supplied URL and the pattern
        # allows any non-whitespace character, so quote it before building
        # the remote command line. Ordinary paths are left unchanged.
        command = f"git-lfs-authenticate {shlex.quote(path)} {action}"
        output = self._ssh.run_command(
            command=command,
            host=host,
            port=port,
            username="git",
        ).read()
        return json.loads(output)


@contextmanager
def _as_atomic(to_info: str, create_parents: bool = False) -> Iterator[str]:
parent = os.path.dirname(to_info)
Expand Down
Loading