Skip to content

Commit

Permalink
make download tests pass again
Browse files Browse the repository at this point in the history
  • Loading branch information
mashehu committed Dec 14, 2023
1 parent ecc93be commit 5aaaa2b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
6 changes: 6 additions & 0 deletions nf_core/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,9 @@ def __init__(

self.setup_local_repo(remote=remote_url, location=location, in_cache=in_cache)

# expose some instance attributes
self.tags = self.repo.tags

def __repr__(self):
"""Called by print, creates representation of object"""
return f"<Locally cached repository: {self.fullname}, revisions {', '.join(self.revision)}\n cached at: {self.local_repo_dir}>"
Expand Down Expand Up @@ -1495,6 +1498,7 @@ def tidy_tags_and_branches(self):
# delete unwanted tags from repository
for tag in tags_to_remove:
self.repo.delete_tag(tag)
self.tags = self.repo.tags

# switch to a revision that should be kept, because deleting heads fails, if they are checked out (e.g. "master")
self.checkout(self.revision[0])
Expand Down Expand Up @@ -1531,6 +1535,8 @@ def tidy_tags_and_branches(self):
if self.repo.head.is_detached:
self.repo.head.reset(index=True, working_tree=True)

self.heads = self.repo.heads

# get all tags and available remote_branches
completed_revisions = {revision.name for revision in self.repo.heads + self.repo.tags}

Expand Down
41 changes: 20 additions & 21 deletions nf_core/synced_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import os
import shutil
from pathlib import Path
from typing import Dict, Union
from typing import Dict

import git
from git.cmd import Git
import rich.progress
from git.exc import GitCommandError
from git.repo import Repo

from nf_core.utils import load_tools_config

Expand Down Expand Up @@ -45,19 +44,19 @@ def __init__(self, progress_bar, repo_name, remote_url, operation):
state="Waiting for response",
)

def update(self, op_code, cur_count, max_count, message=""):
def update(self, op_code, cur_count, max_count=None, message=""):
"""
Overrides git.RemoteProgress.update.
Called every time there is a change in the remote operation
"""
if not self.progress_bar.tasks[self.tid].started:
self.progress_bar.start_task(self.tid)
self.progress_bar.update(
self.tid, total=max_count, completed=cur_count, state=f"{float(cur_count) / float(max_count) * 100:.1f}%"
self.tid, total=max_count, completed=cur_count, state=f"{cur_count / max_count * 100:.1f}%"
)


class SyncedRepo(Repo):
class SyncedRepo:
"""
An object to store details about a locally cached code repository.
"""
Expand Down Expand Up @@ -91,7 +90,7 @@ def get_remote_branches(remote_url):
(set[str]): All branches found in the remote
"""
try:
unparsed_branches = Git().ls_remote(remote_url)
unparsed_branches = git.Git().ls_remote(remote_url)
except git.GitCommandError:
raise LookupError(f"Was unable to fetch branches from '{remote_url}'")
else:
Expand Down Expand Up @@ -175,11 +174,8 @@ def setup_branch(self, branch):
else:
self.branch = branch

# Verify that the branch exists using git
try:
self.checkout_branch()
except git.GitCommandError:
raise LookupError(f"Branch '{self.branch}' not found in '{self.remote_url}'")
# Verify that the branch exists by checking it out
self.branch_exists()

def get_default_branch(self):
"""
Expand All @@ -189,6 +185,15 @@ def get_default_branch(self):
_, branch = origin_head.ref.name.split("/")
return branch

def branch_exists(self):
"""
Verifies that the branch exists in the repository by trying to check it out
"""
try:
self.checkout_branch()
except GitCommandError:
raise LookupError(f"Branch '{self.branch}' not found in '{self.remote_url}'")

def verify_branch(self):
"""
Verifies the active branch conforms to the correct directory structure
Expand All @@ -206,9 +211,7 @@ def checkout_branch(self):
"""
Checks out the specified branch of the repository
"""
# only checkout if we're on a detached head or if we're not already on the branch
if self.repo.head.is_detached or self.repo.active_branch.name != self.branch:
self.repo.git.checkout(self.branch)
self.repo.git.checkout(self.branch)

def checkout(self, commit):
"""
Expand All @@ -217,9 +220,7 @@ def checkout(self, commit):
Args:
commit (str): Git SHA of the commit
"""
# only checkout if we are not already on the commit
if self.repo.head.commit.hexsha != commit:
self.repo.git.checkout(commit)
self.repo.git.checkout(commit)

def component_exists(self, component_name, component_type, checkout=True, commit=None):
"""
Expand Down Expand Up @@ -248,9 +249,7 @@ def get_component_dir(self, component_name, component_type):
elif component_type == "subworkflows":
return os.path.join(self.subworkflows_dir, component_name)

def install_component(
self, component_name: Union[str, Path], install_dir: str, commit: str, component_type: str
) -> bool:
def install_component(self, component_name, install_dir, commit, component_type):
"""
Install the module/subworkflow files into a pipeline at the given commit
Expand Down

0 comments on commit 5aaaa2b

Please sign in to comment.