Skip to content

Commit

Permalink
fix: inplace update (#2427)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostming authored Nov 23, 2023
1 parent 837e7d0 commit 62ceee3
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 36 deletions.
1 change: 1 addition & 0 deletions news/2423.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Updating package now overwrites the old files instead of removing before installing.
9 changes: 4 additions & 5 deletions src/pdm/installers/installers.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _symlink_files(symlink_to: str) -> Iterator[tuple[Scheme, RecordEntry]]:
return super().finalize_installation(scheme, record_file_path, records)


def install_wheel(wheel: str, environment: BaseEnvironment, direct_url: dict[str, Any] | None = None) -> None:
def install_wheel(wheel: str, environment: BaseEnvironment, direct_url: dict[str, Any] | None = None) -> str:
"""Install a normal wheel file into the environment."""
additional_metadata = None
if direct_url is not None:
Expand All @@ -178,12 +178,10 @@ def install_wheel(wheel: str, environment: BaseEnvironment, direct_url: dict[str
interpreter=str(environment.interpreter.executable),
script_kind=_get_kind(environment),
)
_install_wheel(wheel=wheel, destination=destination, additional_metadata=additional_metadata)
return _install_wheel(wheel=wheel, destination=destination, additional_metadata=additional_metadata)


def install_wheel_with_cache(
wheel: str, environment: BaseEnvironment, direct_url: dict[str, Any] | None = None
) -> None:
def install_wheel_with_cache(wheel: str, environment: BaseEnvironment, direct_url: dict[str, Any] | None = None) -> str:
"""Only create .pth files referring to the cached package.
If the cache doesn't exist, create one.
"""
Expand Down Expand Up @@ -245,6 +243,7 @@ def skip_files(source: WheelFile, element: WheelContentElement) -> bool:
additional_metadata=additional_metadata,
)
package_cache.add_referrer(dist_info_dir)
return dist_info_dir


def _install_wheel(
Expand Down
23 changes: 20 additions & 3 deletions src/pdm/installers/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from typing import TYPE_CHECKING

from pdm import termui
from pdm.compat import Distribution
from pdm.exceptions import UninstallError
from pdm.installers.installers import install_wheel, install_wheel_with_cache
from pdm.installers.uninstallers import BaseRemovePaths, StashedRemovePaths

if TYPE_CHECKING:
from pdm.compat import Distribution
from pdm.environments import BaseEnvironment
from pdm.models.candidates import Candidate

Expand All @@ -23,14 +23,16 @@ def __init__(self, environment: BaseEnvironment, *, use_install_cache: bool = Fa
self.environment = environment
self.use_install_cache = use_install_cache

def install(self, candidate: Candidate) -> None:
def install(self, candidate: Candidate) -> Distribution:
"""Install a candidate into the environment, return the distribution"""
if self.use_install_cache and candidate.req.is_named and candidate.name not in self.NO_CACHE_PACKAGES:
# Only cache wheels from PyPI
installer = install_wheel_with_cache
else:
installer = install_wheel
prepared = candidate.prepare(self.environment)
installer(str(prepared.build()), self.environment, prepared.direct_url())
dist_info = installer(str(prepared.build()), self.environment, prepared.direct_url())
return Distribution.at(dist_info)

def get_paths_to_remove(self, dist: Distribution) -> BaseRemovePaths:
"""Get the path collection to be removed from the disk"""
Expand All @@ -48,3 +50,18 @@ def uninstall(self, dist: Distribution) -> None:
termui.logger.info("Error occurred during uninstallation, roll back the changes now.")
remove_path.rollback()
raise UninstallError(e) from e

def overwrite(self, dist: Distribution, candidate: Candidate) -> None:
"""An in-place update to overwrite the distribution with a new candidate"""
paths_to_remove = self.get_paths_to_remove(dist)
termui.logger.info("Overwriting distribution %s", dist.metadata["Name"])
installed = self.install(candidate)
installed_paths = self.get_paths_to_remove(installed)
# Remove the paths that are in the new distribution
paths_to_remove.difference_update(installed_paths)
try:
paths_to_remove.remove()
paths_to_remove.commit()
except OSError as e:
termui.logger.info("Error occurred during overwriting, roll back the changes now.")
raise UninstallError(e) from e
3 changes: 1 addition & 2 deletions src/pdm/installers/synchronizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,7 @@ def update_candidate(self, key: str, progress: Progress) -> tuple[Distribution,
)
can.prepare(self.environment, RichProgressReporter(progress, job))
try:
self.manager.uninstall(dist)
self.manager.install(can)
self.manager.overwrite(dist, can)
except Exception:
progress.live.console.print(
f" [error]{termui.Emoji.FAIL}[/] Update [req]{key}[/] "
Expand Down
48 changes: 33 additions & 15 deletions src/pdm/installers/uninstallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Iterable, TypeVar, cast
from typing import TYPE_CHECKING, Iterable, NewType, TypeVar, cast

from pdm import termui
from pdm.exceptions import UninstallError
Expand All @@ -18,6 +18,7 @@
from pdm.environments import BaseEnvironment

_T = TypeVar("_T", bound="BaseRemovePaths")
NormalizedPath = NewType("NormalizedPath", str)


def renames(old: str, new: str) -> None:
Expand All @@ -37,26 +38,26 @@ def renames(old: str, new: str) -> None:
pass


def compress_for_rename(paths: Iterable[str]) -> set[str]:
def compress_for_rename(paths: Iterable[NormalizedPath]) -> set[NormalizedPath]:
"""Returns a set containing the paths that need to be renamed.
This set may include directories when the original sequence of paths
included every file on disk.
"""
case_map = {os.path.normcase(p): p for p in paths if os.path.exists(p)}
case_map = {NormalizedPath(os.path.normcase(p)): p for p in paths if os.path.exists(p)}
remaining = set(case_map)
unchecked = sorted({os.path.split(p)[0] for p in case_map.values()}, key=len)
wildcards: set[str] = set()
unchecked = sorted({NormalizedPath(os.path.split(p)[0]) for p in case_map.values()}, key=len)
wildcards: set[NormalizedPath] = set()

def norm_join(*a: str) -> str:
return os.path.normcase(os.path.join(*a))
def norm_join(*a: str) -> NormalizedPath:
return NormalizedPath(os.path.normcase(os.path.join(*a)))

for root in unchecked:
if any(os.path.normcase(root).startswith(w) for w in wildcards):
# This directory has already been handled.
continue

all_files: set[str] = set()
all_files: set[NormalizedPath] = set()
for dirname, subdirs, files in os.walk(root):
all_files.update(norm_join(root, dirname, f) for f in files)
for d in subdirs:
Expand All @@ -69,10 +70,10 @@ def norm_join(*a: str) -> str:
# for the directory.
if not (all_files - remaining):
remaining.difference_update(all_files)
wildcards.add(root + os.sep)
wildcards.add(NormalizedPath(root + os.sep))

collected = set(map(case_map.__getitem__, remaining)) | wildcards
shortened: set[str] = set()
shortened: set[NormalizedPath] = set()
# Filter out any paths that are sub paths of another path in the path collection.
for path in sorted(collected, key=len):
if not any(is_path_relative_to(path, p) for p in shortened):
Expand All @@ -91,13 +92,13 @@ def _script_names(script_name: str, is_gui: bool) -> Iterable[str]:
yield script_name + "-script.py"


def _cache_file_from_source(py_file: str) -> Iterable[str]:
def _cache_file_from_source(py_file: NormalizedPath) -> Iterable[NormalizedPath]:
py2_cache = py_file[:-3] + ".pyc"
if os.path.isfile(py2_cache):
yield py2_cache
yield NormalizedPath(py2_cache)
parent, base = os.path.split(py_file)
cache_dir = os.path.join(parent, "__pycache__")
yield from glob.glob(os.path.join(cache_dir, base[:-3] + ".*.pyc"))
yield from map(NormalizedPath, glob.glob(os.path.join(cache_dir, base[:-3] + ".*.pyc")))


def _get_file_root(path: str, base: str) -> str | None:
Expand All @@ -110,16 +111,33 @@ def _get_file_root(path: str, base: str) -> str | None:
return os.path.normcase(os.path.join(base, root))


def _get_all_parents(path: NormalizedPath) -> Iterable[NormalizedPath]:
while True:
yield path
parent = NormalizedPath(os.path.split(path)[0])
if parent == path:
break
path = parent


class BaseRemovePaths(abc.ABC):
"""A collection of paths and/or pth entries to remove"""

def __init__(self, dist: Distribution, environment: BaseEnvironment) -> None:
self.dist = dist
self.environment = environment
self._paths: set[str] = set()
self._paths: set[NormalizedPath] = set()
self._pth_entries: set[str] = set()
self.refer_to: str | None = None

def difference_update(self, other: BaseRemovePaths) -> None:
self._pth_entries.difference_update(other._pth_entries)
for p in other._paths:
# if other_p is a file, remove all parent dirs of it
self._paths.difference_update(_get_all_parents(p))
# other_p is a symlink dir, remove all files under it
self._paths.difference_update({p2 for p2 in self._paths if p2.startswith(p + os.sep)})

@abc.abstractmethod
def remove(self) -> None:
"""Remove the files"""
Expand Down Expand Up @@ -192,7 +210,7 @@ def add_pth(self, line: str) -> None:
self._pth_entries.add(line)

def add_path(self, path: str) -> None:
normalized_path = os.path.normcase(os.path.expanduser(os.path.abspath(path)))
normalized_path = NormalizedPath(os.path.normcase(os.path.expanduser(os.path.abspath(path))))
self._paths.add(normalized_path)
if path.endswith(".py"):
self._paths.update(_cache_file_from_source(normalized_path))
Expand Down
27 changes: 16 additions & 11 deletions src/pdm/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,22 +452,27 @@ def working_set(mocker: MockerFixture, repository: TestRepository) -> MockWorkin
Returns:
a mock working set
"""
from pdm.installers import InstallManager

rv = MockWorkingSet()
mocker.patch.object(BaseEnvironment, "get_working_set", return_value=rv)

def install(candidate: Candidate) -> None:
key = normalize_name(candidate.name or "")
dist = Distribution(key, cast(str, candidate.version), candidate.req.editable)
dist.dependencies = repository.get_raw_dependencies(candidate)
rv.add_distribution(dist)
class MockInstallManager(InstallManager):
def install(self, candidate: Candidate) -> Distribution: # type: ignore[override]
key = normalize_name(candidate.name or "")
dist = Distribution(key, cast(str, candidate.version), candidate.req.editable)
dist.dependencies = repository.get_raw_dependencies(candidate)
rv.add_distribution(dist)
return dist

def uninstall(self, dist: Distribution) -> None: # type: ignore[override]
del rv[dist.name]

def uninstall(dist: Distribution) -> None:
del rv[dist.name]
def overwrite(self, dist: Distribution, candidate: Candidate) -> None: # type: ignore[override]
self.uninstall(dist)
self.install(candidate)

install_manager = mocker.MagicMock()
install_manager.install.side_effect = install
install_manager.uninstall.side_effect = uninstall
mocker.patch("pdm.installers.Synchronizer.get_manager", return_value=install_manager)
mocker.patch.object(Core, "install_manager_class", MockInstallManager)

return rv

Expand Down

0 comments on commit 62ceee3

Please sign in to comment.