diff --git a/news/2037.bugfix b/news/2037.bugfix new file mode 100644 index 00000000000..aca18b07b5e --- /dev/null +++ b/news/2037.bugfix @@ -0,0 +1 @@ +Checkout the correct branch when doing an editable Git install. \ No newline at end of file diff --git a/news/88E800D4-2360-48F6-BD1D-9C6BAB0D059E.trivial b/news/88E800D4-2360-48F6-BD1D-9C6BAB0D059E.trivial new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/pip/_internal/vcs/git.py b/src/pip/_internal/vcs/git.py index bacc037f7ac..7169dc700e3 100644 --- a/src/pip/_internal/vcs/git.py +++ b/src/pip/_internal/vcs/git.py @@ -77,6 +77,20 @@ def get_git_version(self): version = '.'.join(version.split('.')[:3]) return parse_version(version) + def get_branch(self, location): + """ + Return the current branch, or None if HEAD isn't at a branch + (e.g. detached HEAD). + """ + args = ['rev-parse', '--abbrev-ref', 'HEAD'] + output = self.run_command(args, show_stdout=False, cwd=location) + branch = output.strip() + + if branch == 'HEAD': + return None + + return branch + def export(self, location): """Export the Git repository at the url to the destination location""" if not location.endswith('/'): @@ -91,8 +105,8 @@ def export(self, location): def get_revision_sha(self, dest, rev): """ - Return a commit hash for the given revision if it names a remote - branch or tag. Otherwise, return None. + Return (sha_or_none, is_branch), where sha_or_none is a commit hash + if the revision names a remote branch or tag, otherwise None. Args: dest: the repository directory. @@ -115,7 +129,13 @@ def get_revision_sha(self, dest, rev): branch_ref = 'refs/remotes/origin/{}'.format(rev) tag_ref = 'refs/tags/{}'.format(rev) - return refs.get(branch_ref) or refs.get(tag_ref) + sha = refs.get(branch_ref) + if sha is not None: + return (sha, True) + + sha = refs.get(tag_ref) + + return (sha, False) def resolve_revision(self, dest, url, rev_options): """ @@ -126,10 +146,13 @@ def resolve_revision(self, dest, url, rev_options): rev_options: a RevOptions object. """ rev = rev_options.arg_rev - sha = self.get_revision_sha(dest, rev) + sha, is_branch = self.get_revision_sha(dest, rev) if sha is not None: - return rev_options.make_new(sha) + rev_options = rev_options.make_new(sha) + rev_options.branch_name = rev if is_branch else None + + return rev_options # Do not show a warning for the common case of something that has # the form of a Git commit hash. @@ -177,10 +200,20 @@ def fetch_new(self, dest, url, rev_options): if rev_options.rev: # Then a specific revision was requested. rev_options = self.resolve_revision(dest, url, rev_options) - # Only do a checkout if the current commit id doesn't match - # the requested revision. - if not self.is_commit_id_equal(dest, rev_options.rev): - cmd_args = ['checkout', '-q'] + rev_options.to_args() + branch_name = getattr(rev_options, 'branch_name', None) + if branch_name is None: + # Only do a checkout if the current commit id doesn't match + # the requested revision. + if not self.is_commit_id_equal(dest, rev_options.rev): + cmd_args = ['checkout', '-q'] + rev_options.to_args() + self.run_command(cmd_args, cwd=dest) + elif self.get_branch(dest) != branch_name: + # Then a specific branch was requested, and that branch + # is not yet checked out. + track_branch = 'origin/{}'.format(branch_name) + cmd_args = [ + 'checkout', '-b', branch_name, '--track', track_branch, + ] self.run_command(cmd_args, cwd=dest) #: repo may contain submodules diff --git a/tests/functional/test_install_vcs_git.py b/tests/functional/test_install_vcs_git.py index 6648e44fef4..77296baf10e 100644 --- a/tests/functional/test_install_vcs_git.py +++ b/tests/functional/test_install_vcs_git.py @@ -10,15 +10,32 @@ from tests.lib.local_repos import local_checkout +def _get_editable_repo_dir(script, package_name): + """ + Return the repository directory for an editable install. + """ + return script.venv_path / 'src' / package_name + + def _get_editable_branch(script, package_name): """ Return the current branch of an editable install. """ - repo_dir = script.venv_path / 'src' / package_name + repo_dir = _get_editable_repo_dir(script, package_name) result = script.run( 'git', 'rev-parse', '--abbrev-ref', 'HEAD', cwd=repo_dir ) + return result.stdout.strip() + +def _get_branch_remote(script, package_name, branch): + """ + + """ + repo_dir = _get_editable_repo_dir(script, package_name) + result = script.run( + 'git', 'config', 'branch.{}.remote'.format(branch), cwd=repo_dir + ) return result.stdout.strip() @@ -363,7 +380,69 @@ def test_git_works_with_editable_non_origin_repo(script): assert "version-pkg==0.1" in result.stdout -def test_editable_non_master_default_branch(script): +def test_editable__no_revision(script): + """ + Test a basic install in editable mode specifying no revision. + """ + version_pkg_path = _create_test_package(script) + _install_version_pkg_only(script, version_pkg_path) + + branch = _get_editable_branch(script, 'version-pkg') + assert branch == 'master' + + remote = _get_branch_remote(script, 'version-pkg', 'master') + assert remote == 'origin' + + +def test_editable__branch_with_sha_same_as_default(script): + """ + Test installing in editable mode a branch whose sha matches the sha + of the default branch, but is different from the default branch. + """ + version_pkg_path = _create_test_package(script) + # Create a second branch with the same SHA. + script.run( + 'git', 'branch', 'develop', expect_stderr=True, + cwd=version_pkg_path, + ) + _install_version_pkg_only( + script, version_pkg_path, rev='develop', expect_stderr=True + ) + + branch = _get_editable_branch(script, 'version-pkg') + assert branch == 'develop' + + remote = _get_branch_remote(script, 'version-pkg', 'develop') + assert remote == 'origin' + + +def test_editable__branch_with_sha_different_from_default(script): + """ + Test installing in editable mode a branch whose sha is different from + the sha of the default branch. + """ + version_pkg_path = _create_test_package(script) + # Create a second branch. + script.run( + 'git', 'branch', 'develop', expect_stderr=True, + cwd=version_pkg_path, + ) + # Add another commit to the master branch to give it a different sha. + _change_test_package_version(script, version_pkg_path) + + version = _install_version_pkg( + script, version_pkg_path, rev='develop', expect_stderr=True + ) + assert version == '0.1' + + branch = _get_editable_branch(script, 'version-pkg') + assert branch == 'develop' + + remote = _get_branch_remote(script, 'version-pkg', 'develop') + assert remote == 'origin' + + +def test_editable__non_master_default_branch(script): """ Test the branch you get after an editable install from a remote repo with a non-master default branch. @@ -376,8 +455,9 @@ def test_editable_non_master_default_branch(script): cwd=version_pkg_path, ) _install_version_pkg_only(script, version_pkg_path) + branch = _get_editable_branch(script, 'version-pkg') - assert 'release' == branch + assert branch == 'release' def test_reinstalling_works_with_editable_non_master_branch(script): diff --git a/tests/functional/test_vcs_git.py b/tests/functional/test_vcs_git.py index 656cc33ed96..2a54af58fc8 100644 --- a/tests/functional/test_vcs_git.py +++ b/tests/functional/test_vcs_git.py @@ -37,9 +37,9 @@ def add_commits(script, dest, count): return shas -def check_rev(repo_dir, rev, expected_sha): +def check_rev(repo_dir, rev, expected): git = Git() - assert git.get_revision_sha(repo_dir, rev) == expected_sha + assert git.get_revision_sha(repo_dir, rev) == expected def test_git_dir_ignored(): @@ -70,6 +70,27 @@ def test_git_work_tree_ignored(): git.run_command(['status', temp_dir], extra_environ=env, cwd=temp_dir) +def test_get_branch(script, tmpdir): + repo_dir = str(tmpdir) + script.run('git', 'init', cwd=repo_dir) + sha = do_commit(script, repo_dir) + + git = Git() + assert git.get_branch(repo_dir) == 'master' + + # Switch to a branch with the same SHA as "master" but whose name + # is alphabetically after. + script.run( + 'git', 'checkout', '-b', 'release', cwd=repo_dir, + expect_stderr=True, + ) + assert git.get_branch(repo_dir) == 'release' + + # Also test the detached HEAD case. + script.run('git', 'checkout', sha, cwd=repo_dir, expect_stderr=True) + assert git.get_branch(repo_dir) is None + + def test_get_revision_sha(script): with TempDirectory(kind="testing") as temp: repo_dir = temp.path @@ -102,9 +123,9 @@ def test_get_revision_sha(script): script.run('git', 'tag', 'aaa/v1.0', head_sha, cwd=repo_dir) script.run('git', 'tag', 'zzz/v1.0', head_sha, cwd=repo_dir) - check_rev(repo_dir, 'v1.0', tag_sha) - check_rev(repo_dir, 'v2.0', tag_sha) - check_rev(repo_dir, 'origin-branch', origin_sha) + check_rev(repo_dir, 'v1.0', (tag_sha, False)) + check_rev(repo_dir, 'v2.0', (tag_sha, False)) + check_rev(repo_dir, 'origin-branch', (origin_sha, True)) ignored_names = [ # Local branches should be ignored. @@ -122,7 +143,7 @@ def test_get_revision_sha(script): 'does-not-exist', ] for name in ignored_names: - check_rev(repo_dir, name, None) + check_rev(repo_dir, name, (None, False)) @pytest.mark.network diff --git a/tests/unit/test_vcs.py b/tests/unit/test_vcs.py index c9ad863cc4c..7e8934c7642 100644 --- a/tests/unit/test_vcs.py +++ b/tests/unit/test_vcs.py @@ -109,7 +109,7 @@ def test_git_get_src_requirements(git, dist): @patch('pip._internal.vcs.git.Git.get_revision_sha') def test_git_resolve_revision_rev_exists(get_sha_mock): - get_sha_mock.return_value = '123456' + get_sha_mock.return_value = ('123456', False) git = Git() rev_options = git.make_rev_options('develop') @@ -120,7 +120,7 @@ def test_git_resolve_revision_rev_exists(get_sha_mock): @patch('pip._internal.vcs.git.Git.get_revision_sha') def test_git_resolve_revision_rev_not_found(get_sha_mock): - get_sha_mock.return_value = None + get_sha_mock.return_value = (None, False) git = Git() rev_options = git.make_rev_options('develop') @@ -131,7 +131,7 @@ def test_git_resolve_revision_rev_not_found(get_sha_mock): @patch('pip._internal.vcs.git.Git.get_revision_sha') def test_git_resolve_revision_not_found_warning(get_sha_mock, caplog): - get_sha_mock.return_value = None + get_sha_mock.return_value = (None, False) git = Git() url = 'git+https://git.example.com'