Skip to content
This repository has been archived by the owner on Jun 21, 2023. It is now read-only.

[spike] Automatically detect the PR branch the user is on #2161

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/GitHub.App/Services/GitClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public class GitClient : IGitClient
const string defaultOriginName = "origin";
static readonly ILogger log = LogManager.ForContext<GitClient>();
readonly IGitService gitService;
readonly IGitHubCredentialProvider credentialProvider;
readonly PullOptions pullOptions;
readonly PushOptions pushOptions;
readonly FetchOptions fetchOptions;
Expand All @@ -31,6 +32,7 @@ public GitClient(IGitHubCredentialProvider credentialProvider, IGitService gitSe
Guard.ArgumentNotNull(gitService, nameof(gitService));

this.gitService = gitService;
this.credentialProvider = credentialProvider;

pushOptions = new PushOptions { CredentialsProvider = credentialProvider.HandleCredentials };
fetchOptions = new FetchOptions { CredentialsProvider = credentialProvider.HandleCredentials };
Expand Down Expand Up @@ -168,6 +170,22 @@ public Task Fetch(IRepository repository, string remoteName, params string[] ref
});
}

public Task<IDictionary<string, string>> ListReferences(IRepository repo, string remoteName)
{
return Task.Run<IDictionary<string, string>>(() =>
{
var dictionary = new Dictionary<string, string>();
var remote = repo.Network.Remotes[remoteName];
var refs = repo.Network.ListReferences(remote, credentialProvider.HandleCredentials);
foreach (var reference in refs)
{
dictionary[reference.CanonicalName] = reference.TargetIdentifier;
}

return dictionary;
});
}

public Task Checkout(IRepository repository, string branchName)
{
Guard.ArgumentNotNull(repository, nameof(repository));
Expand Down
57 changes: 55 additions & 2 deletions src/GitHub.App/Services/PullRequestService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public class PullRequestService : IPullRequestService, IStaticReviewFileMap
readonly IUsageTracker usageTracker;

readonly IDictionary<string, (string commitId, string repoPath)> tempFileMappings;

[ImportingConstructor]
public PullRequestService(
IGitClient gitClient,
Expand Down Expand Up @@ -738,11 +738,64 @@ public IObservable<Unit> SwitchToBranch(LocalRepositoryModel repository, PullReq
repo.Head.FriendlyName,
SettingGHfVSPullRequest);
var value = await gitClient.GetConfig<string>(repo, configKey);
return Observable.Return(ParseGHfVSConfigKeyValue(value));
var pr = ParseGHfVSConfigKeyValue(value);
if (pr != default((string, int)))
{
return Observable.Return(pr);
}

pr = await FindPullRequestForBranchAsync(repo, repo.Head, "origin");
return Observable.Return(pr);
}
});
}

async Task<(string owner, int number)> FindPullRequestForBranchAsync(
IRepository repo, Branch branch, string upstreamRemoteName = "origin")
{
if (!branch.IsTracking)
{
return default((string, int));
}

var remoteReferences = await gitClient.ListReferences(repo, branch.RemoteName);
if (!remoteReferences.TryGetValue(branch.UpstreamBranchCanonicalName, out var sha))
{
return default((string, int));
}

if (branch.RemoteName != upstreamRemoteName)
{
remoteReferences = await gitClient.ListReferences(repo, upstreamRemoteName);
}

var prs = remoteReferences
.Where(kv => kv.Value == sha)
.Select(kv => FindPullRequestForCanonicalName(kv.Key))
.Where(p => p != -1)
.ToList();
if (prs.Count == 0)
{
return default((string, int));
}

var owner = gitService.GetRemoteUri(repo, upstreamRemoteName).Owner;
var number = prs[0];

return (owner, number);
}

static int FindPullRequestForCanonicalName(string canonicalName)
{
var match = Regex.Match(canonicalName, "^refs/pull/([0-9]+)/head$");
if (match.Success && int.TryParse(match.Groups[1].Value, out var number))
{
return number;
}

return -1;
}

public async Task<string> ExtractToTempFile(
LocalRepositoryModel repository,
PullRequestDetailModel pullRequest,
Expand Down
3 changes: 3 additions & 0 deletions src/GitHub.Exports.Reactive/Services/IGitClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ public interface IGitClient
/// <returns></returns>
Task Fetch(IRepository repository, UriString remoteUri, params string[] refspecs);

// blar!
Task<IDictionary<string, string>> ListReferences(IRepository repo, string remoteName);

/// <summary>
/// Checks out a branch.
/// </summary>
Expand Down