Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,4 +258,5 @@ keep_body=False
remote=origin
target=main
reviewer=GithubHandle1,GithubHandle2
branch_name_template=$USERNAME/stack
```
61 changes: 46 additions & 15 deletions src/stack_pr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import json
import os
import re
from functools import cache
from subprocess import SubprocessError

from stack_pr.git import (
Expand Down Expand Up @@ -382,12 +383,14 @@ def split_header(s: str) -> List[CommitHeader]:
return [CommitHeader(h) for h in s.split("\0")[:-1]]


def is_valid_ref(ref: str) -> bool:
def is_valid_ref(ref: str, branch_name_template: str) -> bool:
ref = ref.strip("'")
splits = ref.rsplit("/", 2)
if len(splits) < 3:

branch_name_base = get_branch_name_base(branch_name_template)
splits = ref.rsplit("/", 1)
if len(splits) < 2:
return False
return splits[-2] == "stack" and splits[-1].isnumeric()
return splits[-2].endswith(branch_name_base) and splits[-1].isnumeric()


def last(ref: str, sep: str = "/") -> str:
Expand Down Expand Up @@ -555,43 +558,57 @@ def add_or_update_metadata(
return True


def get_available_branch_name(remote: str) -> str:
@cache
def get_branch_name_base(branch_name_template: str):
username = get_gh_username()
branch_name_base = branch_name_template.replace("$USERNAME", username)
return branch_name_base


def get_available_branch_name(remote: str, branch_name_template: str) -> str:
branch_name_base = get_branch_name_base(branch_name_template)

refs = get_command_output(
[
"git",
"for-each-ref",
f"refs/remotes/{remote}/{username}/stack",
f"refs/remotes/{remote}/{branch_name_base}",
"--format='%(refname)'",
]
).split()

refs = list(filter(is_valid_ref, refs))
def check_ref(ref):
return is_valid_ref(ref, branch_name_base)

refs = list(filter(check_ref, refs))
max_ref_num = max(int(last(ref.strip("'"))) for ref in refs) if refs else 0
new_branch_id = max_ref_num + 1

return f"{username}/stack/{new_branch_id}"
return f"{branch_name_base}/{new_branch_id}"


def get_next_available_branch_name(name: str) -> str:
base, id = name.rsplit("/", 1)
return f"{base}/{int(id) + 1}"


def set_head_branches(st: List[StackEntry], remote: str, verbose: bool):
def set_head_branches(
st: List[StackEntry], remote: str, verbose: bool, branch_name_template: str
):
"""Set the head ref for each stack entry if it doesn't already have one."""

run_shell_command(["git", "fetch", "--prune", remote], quiet=not verbose)
available_name = get_available_branch_name(remote)
available_name = get_available_branch_name(remote, branch_name_template)
for e in filter(lambda e: not e.has_head(), st):
e.head = available_name
available_name = get_next_available_branch_name(available_name)


def init_local_branches(st: List[StackEntry], remote: str, verbose: bool):
def init_local_branches(
st: List[StackEntry], remote: str, verbose: bool, branch_name_template: str
):
log(h("Initializing local branches"), level=1)
set_head_branches(st, remote, verbose)
set_head_branches(st, remote, verbose, branch_name_template)
for e in st:
run_shell_command(
["git", "checkout", e.commit.commit_id(), "-B", e.head],
Expand Down Expand Up @@ -785,6 +802,7 @@ class CommonArgs(NamedTuple):
target: str
hyperlinks: bool
verbose: bool
branch_name_template: str

@classmethod
def from_args(cls, args: argparse.Namespace) -> "CommonArgs":
Expand All @@ -795,6 +813,7 @@ def from_args(cls, args: argparse.Namespace) -> "CommonArgs":
args.target,
args.hyperlinks,
args.verbose,
args.branch_name_template,
)


Expand Down Expand Up @@ -822,6 +841,7 @@ def deduce_base(args: CommonArgs) -> CommonArgs:
args.target,
args.hyperlinks,
args.verbose,
args.branch_name_template,
)


Expand Down Expand Up @@ -876,7 +896,9 @@ def command_submit(

# Create local branches and initialize base and head fields in the stack
# elements
init_local_branches(st, args.remote, args.verbose)
init_local_branches(
st, args.remote, args.verbose, args.branch_name_template
)
set_base_branches(st, args.target)
print_stack(st, args.hyperlinks)

Expand Down Expand Up @@ -1137,7 +1159,9 @@ def command_abandon(args: CommonArgs):
return
current_branch = get_current_branch_name()

init_local_branches(st, args.remote, args.verbose)
init_local_branches(
st, args.remote, args.verbose, args.branch_name_template
)
set_base_branches(st, args.target)
print_stack(st, args.hyperlinks)

Expand Down Expand Up @@ -1219,7 +1243,7 @@ def command_view(args: CommonArgs):

st = get_stack(args.base, args.head, args.verbose)

set_head_branches(st, args.remote, args.verbose)
set_head_branches(st, args.remote, args.verbose, args.branch_name_template)
set_base_branches(st, args.target)
print_stack(st, args.hyperlinks)
print_tips_after_view(st, args)
Expand Down Expand Up @@ -1268,6 +1292,13 @@ def create_argparser(
default=config.getboolean("common", "verbose", fallback=False),
help="Enable verbose output from Git subcommands.",
)
common_parser.add_argument(
"--branch-name-template",
default=config.get(
"repo", "branch_name_template", fallback="$USERNAME/stack"
),
help="A template for names of the branches stack-pr would use.",
)

parser_submit = subparsers.add_parser(
"submit",
Expand Down