diff --git a/README.md b/README.md index 6cff305..03468f4 100644 --- a/README.md +++ b/README.md @@ -250,7 +250,10 @@ These arguments can be used with any subcommand: - `-T, --target`: Remote target branch (default: "main") - `--hyperlinks/--no-hyperlinks`: Enable/disable hyperlink support (default: enabled) - `-V, --verbose`: Enable verbose output from Git subcommands (default: false) -- `--branch-name-template`: Template for generated branch names (default: "$USERNAME/stack") +- `--branch-name-template`: Template for generated branch names (default: "$USERNAME/stack"). The following variables are supported: + - `$USERNAME`: The username of the current user + - `$BRANCH`: The current branch name + - `$ID`: The location for the ID of the branch. The ID is determined by the order of creation of the branches. If `$ID` is not found in the template, the template will be appended with `/$ID`. ### Subcommands diff --git a/src/stack_pr/cli.py b/src/stack_pr/cli.py index 12170d3..45a9c3d 100755 --- a/src/stack_pr/cli.py +++ b/src/stack_pr/cli.py @@ -286,7 +286,7 @@ def pprint(self, links: bool): s = b(self.commit.commit_id()[:8]) pr_string = None if self.has_pr(): - pr_string = blue("#" + self.pr.split("/")[-1]) + pr_string = blue("#" + last(self.pr)) else: pr_string = red("no PR") branch_string = None @@ -384,16 +384,6 @@ def split_header(s: str) -> List[CommitHeader]: return [CommitHeader(h) for h in s.split("\0")[:-1]] -def is_valid_ref(ref: str, branch_name_template: str) -> bool: - ref = ref.strip("'") - - branch_name_base = get_branch_name_base(branch_name_template) - splits = ref.rsplit("/", 1) - if len(splits) < 2: - return False - return splits[-2].endswith(branch_name_base) and splits[-1].isnumeric() - - def last(ref: str, sep: str = "/") -> str: return ref.rsplit(sep, 1)[-1] @@ -559,6 +549,13 @@ def add_or_update_metadata(e: StackEntry, needs_rebase: bool, verbose: bool) -> return True +def fix_branch_name_template(branch_name_template: str): + if "$ID" not in branch_name_template: + return f"{branch_name_template}/$ID" + + return branch_name_template + + @cache def get_branch_name_base(branch_name_template: str): username = get_gh_username() @@ -568,31 +565,54 @@ def get_branch_name_base(branch_name_template: str): return branch_name_base +def get_branch_id(branch_name_template: str, branch_name: str): + branch_name_base = get_branch_name_base(branch_name_template) + pattern = branch_name_base.replace(r"$ID", r"(\d+)") + match = re.search(pattern, branch_name) + if match: + return match.group(1) + return None + + +def generate_branch_name(branch_name_template: str, branch_id: int): + branch_name_base = get_branch_name_base(branch_name_template) + branch_name = branch_name_base.replace(r"$ID", branch_id) + return branch_name + + +def get_taken_branch_ids(refs: List[str], branch_name_template: str) -> List[int]: + branch_ids = list(get_branch_id(branch_name_template, ref) for ref in refs) + branch_ids = [int(branch_id) for branch_id in branch_ids if branch_id is not None] + return branch_ids + + +def generate_available_branch_name(refs: List[str], branch_name_template: str) -> str: + branch_ids = get_taken_branch_ids(refs, branch_name_template) + max_ref_num = max(branch_ids) if branch_ids else 0 + new_branch_id = max_ref_num + 1 + return generate_branch_name(branch_name_template, str(new_branch_id)) + + def get_available_branch_name(remote: str, branch_name_template: str) -> str: branch_name_base = get_branch_name_base(branch_name_template) + git_command_branch_template = branch_name_base.replace(r"$ID", "*") refs = get_command_output( [ "git", "for-each-ref", - f"refs/remotes/{remote}/{branch_name_base}", + f"refs/remotes/{remote}/{git_command_branch_template}", "--format='%(refname)'", ] ).split() - def check_ref(ref): - return is_valid_ref(ref, branch_name_base) - - refs = list(filter(check_ref, refs)) - max_ref_num = max(int(last(ref.strip("'"))) for ref in refs) if refs else 0 - new_branch_id = max_ref_num + 1 + refs = list([ref.strip("'") for ref in refs]) + return generate_available_branch_name(refs, branch_name_template) - return f"{branch_name_base}/{new_branch_id}" - -def get_next_available_branch_name(name: str) -> str: - base, id = name.rsplit("/", 1) - return f"{base}/{int(id) + 1}" +def get_next_available_branch_name(branch_name_template: str, name: str) -> str: + id = get_branch_id(branch_name_template, name) + return generate_branch_name(branch_name_template, str(int(id) + 1)) def set_head_branches( @@ -604,7 +624,9 @@ def set_head_branches( available_name = get_available_branch_name(remote, branch_name_template) for e in filter(lambda e: not e.has_head(), st): e.head = available_name - available_name = get_next_available_branch_name(available_name) + available_name = get_next_available_branch_name( + branch_name_template, available_name + ) def init_local_branches( @@ -1359,6 +1381,8 @@ def main(): parser.print_help() return + # Make sure "$ID" is present in the branch name template and append it if not + args.branch_name_template = fix_branch_name_template(args.branch_name_template) common_args = CommonArgs.from_args(args) if common_args.verbose: diff --git a/tests/test_misc.py b/tests/test_misc.py new file mode 100644 index 0000000..1587a7d --- /dev/null +++ b/tests/test_misc.py @@ -0,0 +1,83 @@ +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent / "src")) + +from stack_pr.cli import ( + get_branch_id, + generate_branch_name, + get_taken_branch_ids, + get_gh_username, + generate_available_branch_name, +) + +import pytest + + +@pytest.fixture(scope="module") +def username(): + return get_gh_username() + + +@pytest.mark.parametrize( + "template,branch_name,expected", + [ + ("feature-$ID-desc", "feature-123-desc", "123"), + ("$USERNAME/stack/$ID", "{username}/stack/99", "99"), + ("$USERNAME/stack/$ID", "refs/remote/origin/{username}/stack/99", "99"), + ], +) +def test_get_branch_id(username, template, branch_name, expected): + branch_name = branch_name.format(username=username) + assert get_branch_id(template, branch_name) == expected + + +@pytest.mark.parametrize( + "template,branch_name", + [ + ("feature/$ID/desc", "feature/abc/desc"), + ("feature/$ID/desc", "wrong/format"), + ("$USERNAME/stack/$ID", "{username}/main/99"), + ], +) +def test_get_branch_id_no_match(username, template, branch_name): + branch_name = branch_name.format(username=username) + assert get_branch_id(template, branch_name) is None + + +def test_generate_branch_name(): + template = "feature/$ID/description" + assert generate_branch_name(template, "123") == "feature/123/description" + + +def test_get_taken_branch_ids(): + template = "User/stack/$ID" + refs = [ + "refs/remotes/origin/User/stack/104", + "refs/remotes/origin/User/stack/105", + "refs/remotes/origin/User/stack/134", + ] + assert get_taken_branch_ids(refs, template) == [104, 105, 134] + refs = ["User/stack/104", "User/stack/105", "User/stack/134"] + assert get_taken_branch_ids(refs, template) == [104, 105, 134] + refs = ["User/stack/104", "AAAA/stack/105", "User/stack/134", "User/stack/bbb"] + assert get_taken_branch_ids(refs, template) == [104, 134] + + +def test_generate_available_branch_name(): + template = "User/stack/$ID" + refs = [ + "refs/remotes/origin/User/stack/104", + "refs/remotes/origin/User/stack/105", + "refs/remotes/origin/User/stack/134", + ] + assert generate_available_branch_name(refs, template) == "User/stack/135" + refs = [] + assert generate_available_branch_name(refs, template) == "User/stack/1" + template = "User-stack-$ID" + refs = [ + "refs/remotes/origin/User-stack-104", + "refs/remotes/origin/User-stack-105", + "refs/remotes/origin/User-stack-134", + ] + assert generate_available_branch_name(refs, template) == "User-stack-135"