diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f147eb6..85941e1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ All versions prior to 0.0.9 are untracked. ### Added +* CLI: The `--fix` flag has been added, allowing users to attempt to + automatically upgrade any vulnerable dependencies to the first safe version + available (#[212](https://github.com/trailofbits/pip-audit/pull/212)) + ### Changed ### Fixed diff --git a/README.md b/README.md index e3ed3005..33754fef 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ python -m pip_audit --help usage: pip-audit [-h] [-V] [-l] [-r REQUIREMENTS] [-f FORMAT] [-s SERVICE] [-d] [-S] [--desc [{on,off,auto}]] [--cache-dir CACHE_DIR] [--progress-spinner {on,off}] [--timeout TIMEOUT] - [--path PATHS] [-v] + [--path PATHS] [-v] [--fix] audit the Python environment for dependencies with known vulnerabilities @@ -111,6 +111,8 @@ optional arguments: -v, --verbose give more output; this setting overrides the `PIP_AUDIT_LOGLEVEL` variable and is equivalent to setting it to `debug` (default: False) + --fix automatically upgrade dependencies with known + vulnerabilities (default: False) ``` @@ -216,6 +218,16 @@ Found 2 known vulnerabilities in 1 packages ] ``` +Audit and attempt to automatically upgrade vulnerable dependencies: +``` +$ pip-audit --fix +Found 2 known vulnerabilities in 1 packages and fixed 2 vulnerabilities in 1 packages +Name Version ID Fix Versions +----- ------- -------------- ------------ +Flask 0.5 PYSEC-2019-179 1.0 +Flask 0.5 PYSEC-2018-66 0.12.3 +``` + ## Security Model This section exists to describe the security assumptions you **can** and **must not** diff --git a/pip_audit/_cli.py b/pip_audit/_cli.py index f8e3f0f5..a5608509 100644 --- a/pip_audit/_cli.py +++ b/pip_audit/_cli.py @@ -19,6 +19,8 @@ RequirementSource, ResolveLibResolver, ) +from pip_audit._dependency_source.interface import DependencySourceError +from pip_audit._fix import ResolvedFixVersion, SkippedFixVersion, resolve_fix_versions from pip_audit._format import ColumnsFormat, CycloneDxFormat, JsonFormat, VulnerabilityFormat from pip_audit._service import OsvService, PyPIService, VulnerabilityService from pip_audit._service.interface import ResolvedDependency, SkippedDependency @@ -234,6 +236,11 @@ def audit() -> None: help="give more output; this setting overrides the `PIP_AUDIT_LOGLEVEL` variable and is " "equivalent to setting it to `debug`", ) + parser.add_argument( + "--fix", + action="store_true", + help="automatically upgrade dependencies with known vulnerabilities", + ) args = parser.parse_args() if args.verbose: @@ -280,11 +287,34 @@ def audit() -> None: pkg_count += 1 vuln_count += len(vulns) + # If the `--fix` flag has been applied, find a set of suitable fix versions and upgrade the + # dependencies at the source + fixes = list() + fixed_pkg_count = 0 + fixed_vuln_count = 0 + if args.fix: + for fix_version in resolve_fix_versions(service, result): + if not fix_version.is_skipped(): + fix_version = cast(ResolvedFixVersion, fix_version) + try: + source.fix(fix_version) + fixed_pkg_count += 1 + fixed_vuln_count += len(result[fix_version.dep]) + except DependencySourceError as dse: + fix_version = SkippedFixVersion(fix_version.dep, str(dse)) + fixes.append(fix_version) + # TODO(ww): Refine this: we should always output if our output format is an SBOM # or other manifest format (like the default JSON format). if vuln_count > 0: - print(f"Found {vuln_count} known vulnerabilities in {pkg_count} packages", file=sys.stderr) + summary_msg = f"Found {vuln_count} known vulnerabilities in {pkg_count} packages" + if args.fix: + summary_msg += ( + f" and fixed {fixed_vuln_count} vulnerabilities in {fixed_pkg_count} packages" + ) + print(summary_msg, file=sys.stderr) print(formatter.format(result)) - sys.exit(1) + if pkg_count != fixed_pkg_count: + sys.exit(1) else: print("No known vulnerabilities found", file=sys.stderr) diff --git a/pip_audit/_dependency_source/__init__.py b/pip_audit/_dependency_source/__init__.py index 66c8cd05..ab19146d 100644 --- a/pip_audit/_dependency_source/__init__.py +++ b/pip_audit/_dependency_source/__init__.py @@ -3,6 +3,7 @@ """ from .interface import ( + DependencyFixError, DependencyResolver, DependencyResolverError, DependencySource, @@ -13,6 +14,7 @@ from .resolvelib import ResolveLibResolver __all__ = [ + "DependencyFixError", "DependencyResolver", "DependencyResolverError", "DependencySource", diff --git a/pip_audit/_dependency_source/interface.py b/pip_audit/_dependency_source/interface.py index 13f49c39..6bd5590f 100644 --- a/pip_audit/_dependency_source/interface.py +++ b/pip_audit/_dependency_source/interface.py @@ -8,6 +8,7 @@ from packaging.requirements import Requirement +from pip_audit._fix import ResolvedFixVersion from pip_audit._service import Dependency @@ -26,6 +27,13 @@ def collect(self) -> Iterator[Dependency]: # pragma: no cover """ raise NotImplementedError + @abstractmethod + def fix(self, fix_version: ResolvedFixVersion) -> None: # pragma: no cover + """ + Upgrade a dependency to the given fix version. + """ + raise NotImplementedError + class DependencySourceError(Exception): """ @@ -38,6 +46,18 @@ class DependencySourceError(Exception): pass +class DependencyFixError(Exception): + """ + Raised when a `DependencySource` fails to perform a "fix" operation, i.e. + fails to upgrade a package to a different version. + + Concrete implementations are expected to subclass this exception to provide + more context. + """ + + pass + + class DependencyResolver(ABC): """ Represents an abstract resolver of Python dependencies that takes a single diff --git a/pip_audit/_dependency_source/pip.py b/pip_audit/_dependency_source/pip.py index 61c574b9..db19559f 100644 --- a/pip_audit/_dependency_source/pip.py +++ b/pip_audit/_dependency_source/pip.py @@ -4,13 +4,16 @@ """ import logging +import subprocess +import sys from pathlib import Path from typing import Iterator, Sequence import pip_api from packaging.version import InvalidVersion, Version -from pip_audit._dependency_source import DependencySource, DependencySourceError +from pip_audit._dependency_source import DependencyFixError, DependencySource, DependencySourceError +from pip_audit._fix import ResolvedFixVersion from pip_audit._service import Dependency, ResolvedDependency, SkippedDependency from pip_audit._state import AuditState @@ -87,8 +90,35 @@ def collect(self) -> Iterator[Dependency]: except Exception as e: raise PipSourceError("failed to list installed distributions") from e + def fix(self, fix_version: ResolvedFixVersion) -> None: + """ + Fixes a dependency version in this `PipSource`. + """ + fix_cmd = [ + sys.executable, + "-m", + "pip", + "install", + f"{fix_version.dep.canonical_name}=={fix_version.version}", + ] + try: + subprocess.run( + fix_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + except subprocess.CalledProcessError as cpe: + raise PipFixError( + f"failed to upgrade dependency {fix_version.dep.name} to fix version " + f"{fix_version.version}" + ) from cpe + class PipSourceError(DependencySourceError): """A `pip` specific `DependencySourceError`.""" pass + + +class PipFixError(DependencyFixError): + """A `pip` specific `DependencyFixError`.""" + + pass diff --git a/pip_audit/_dependency_source/requirement.py b/pip_audit/_dependency_source/requirement.py index cf3b3c28..a8804f93 100644 --- a/pip_audit/_dependency_source/requirement.py +++ b/pip_audit/_dependency_source/requirement.py @@ -15,6 +15,7 @@ DependencySource, DependencySourceError, ) +from pip_audit._fix import ResolvedFixVersion from pip_audit._service import Dependency from pip_audit._service.interface import ResolvedDependency, SkippedDependency from pip_audit._state import AuditState @@ -78,6 +79,12 @@ def collect(self) -> Iterator[Dependency]: except DependencyResolverError as dre: raise RequirementSourceError("dependency resolver raised an error") from dre + def fix(self, fix_version: ResolvedFixVersion) -> None: # pragma: no cover + """ + Fixes a dependency version for this `RequirementSource`. + """ + raise NotImplementedError + class RequirementSourceError(DependencySourceError): """A requirements-parsing specific `DependencySourceError`.""" diff --git a/pip_audit/_fix.py b/pip_audit/_fix.py new file mode 100644 index 00000000..33cb203b --- /dev/null +++ b/pip_audit/_fix.py @@ -0,0 +1,111 @@ +""" +Functionality for resolving fixed versions of dependencies. +""" + +from dataclasses import dataclass +from typing import Dict, Iterator, List, cast + +from packaging.version import Version + +from pip_audit._service import ( + Dependency, + ResolvedDependency, + VulnerabilityResult, + VulnerabilityService, +) + + +@dataclass(frozen=True) +class FixVersion: + """ + Represents an abstract dependency fix version. + + This class cannot be constructed directly. + """ + + dep: ResolvedDependency + + def __init__(self, *_args, **_kwargs) -> None: # pragma: no cover + """ + A stub constructor that always fails. + """ + raise NotImplementedError + + def is_skipped(self) -> bool: + """ + Check whether the `FixVersion` was unable to be resolved. + """ + return self.__class__ is SkippedFixVersion + + +@dataclass(frozen=True) +class ResolvedFixVersion(FixVersion): + """ + Represents a resolved fix version. + """ + + version: Version + + +@dataclass(frozen=True) +class SkippedFixVersion(FixVersion): + """ + Represents a fix version that was unable to be resolved and therefore, skipped. + """ + + skip_reason: str + + +def resolve_fix_versions( + service: VulnerabilityService, result: Dict[Dependency, List[VulnerabilityResult]] +) -> Iterator[FixVersion]: + """ + Resolves a mapping of dependencies to known vulnerabilities to a series of fix versions without + known vulnerabilties. + """ + for (dep, vulns) in result.items(): + if dep.is_skipped(): + continue + if not vulns: + continue + dep = cast(ResolvedDependency, dep) + try: + version = _resolve_fix_version(service, dep, vulns) + yield ResolvedFixVersion(dep, version) + except FixResolutionImpossible as fri: + yield SkippedFixVersion(dep, str(fri)) + + +def _resolve_fix_version( + service: VulnerabilityService, dep: ResolvedDependency, vulns: List[VulnerabilityResult] +) -> Version: + # We need to upgrade to a fix version that satisfies all vulnerability results + # + # However, whenever we upgrade a dependency, we run the risk of introducing new vulnerabilities + # so we need to run this in a loop and continue polling the vulnerability service on each + # prospective resolved fix version + current_version = dep.version + current_vulns = vulns + while current_vulns: + + def get_earliest_fix_version(d: ResolvedDependency, v: VulnerabilityResult) -> Version: + for fix_version in v.fix_versions: + if fix_version > current_version: + return fix_version + raise FixResolutionImpossible( + f"failed to fix dependency {dep.name} ({dep.version}), unable to find fix version " + f"for vulnerability {v.id}" + ) + + # We want to retrieve a version that potentially fixes all vulnerabilities + current_version = max([get_earliest_fix_version(dep, v) for v in current_vulns]) + _, current_vulns = service.query(ResolvedDependency(dep.name, current_version)) + return current_version + + +class FixResolutionImpossible(Exception): + """ + Raised when `resolve_fix_versions` fails to find a fix version without known vulnerabilities + """ + + pass diff --git a/test/conftest.py b/test/conftest.py index b6366be8..a0e97ea3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -48,6 +48,9 @@ class Source(DependencySource): def collect(self): yield spec("1.0.1") + def fix(self, _) -> None: + raise NotImplementedError + return Source diff --git a/test/dependency_source/test_pip.py b/test/dependency_source/test_pip.py index 9a62aa5d..c8506ab8 100644 --- a/test/dependency_source/test_pip.py +++ b/test/dependency_source/test_pip.py @@ -1,4 +1,6 @@ import os +import subprocess +import sys from dataclasses import dataclass from typing import Dict, List @@ -8,6 +10,7 @@ from packaging.version import Version from pip_audit._dependency_source import pip +from pip_audit._fix import ResolvedFixVersion from pip_audit._service.interface import ResolvedDependency, SkippedDependency @@ -82,3 +85,35 @@ def mock_installed_distributions( in specs ) assert ResolvedDependency(name="pip-api", version=Version("1.0")) in specs + + +def test_pip_source_fix(monkeypatch): + source = pip.PipSource() + + fix_version = ResolvedFixVersion( + dep=ResolvedDependency(name="pip-api", version=Version("1.0")), version=Version("1.5") + ) + + def run_mock(args, **kwargs): + assert " ".join(args) == f"{sys.executable} -m pip install pip-api==1.5" + + monkeypatch.setattr(subprocess, "run", run_mock) + + source.fix(fix_version) + + +def test_pip_source_fix_failure(monkeypatch): + source = pip.PipSource() + + fix_version = ResolvedFixVersion( + dep=ResolvedDependency(name="pip-api", version=Version("1.0")), version=Version("1.5") + ) + + def run_mock(args, **kwargs): + assert " ".join(args) == f"{sys.executable} -m pip install pip-api==1.5" + raise subprocess.CalledProcessError(-1, str()) + + monkeypatch.setattr(subprocess, "run", run_mock) + + with pytest.raises(pip.PipFixError): + source.fix(fix_version) diff --git a/test/test_fix.py b/test/test_fix.py new file mode 100644 index 00000000..71cfe789 --- /dev/null +++ b/test/test_fix.py @@ -0,0 +1,69 @@ +from typing import Dict, List + +from packaging.version import Version + +from pip_audit._fix import ResolvedFixVersion, SkippedFixVersion, resolve_fix_versions +from pip_audit._service import ( + Dependency, + ResolvedDependency, + SkippedDependency, + VulnerabilityResult, +) + + +def test_fix(vuln_service): + dep = ResolvedDependency(name="foo", version=Version("0.5.0")) + result: Dict[Dependency, List[VulnerabilityResult]] = { + dep: [ + VulnerabilityResult( + id="fake-id", + description="this is not a real result", + fix_versions=[Version("1.0.0")], + ) + ] + } + fix_versions = list(resolve_fix_versions(vuln_service(), result)) + assert len(fix_versions) == 1 + assert fix_versions[0] == ResolvedFixVersion(dep=dep, version=Version("1.1.0")) + assert not fix_versions[0].is_skipped() + + +def test_fix_skipped_deps(vuln_service): + dep = SkippedDependency(name="foo", skip_reason="skip-reason") + result: Dict[Dependency, List[VulnerabilityResult]] = { + dep: [ + VulnerabilityResult( + id="fake-id", + description="this is not a real result", + fix_versions=[Version("1.0.0")], + ) + ] + } + fix_versions = list(resolve_fix_versions(vuln_service(), result)) + assert not fix_versions + + +def test_fix_no_vulns(vuln_service): + dep = ResolvedDependency(name="foo", version=Version("0.5.0")) + result: Dict[Dependency, List[VulnerabilityResult]] = {dep: list()} + fix_versions = list(resolve_fix_versions(vuln_service(), result)) + assert not fix_versions + + +def test_fix_resolution_impossible(vuln_service): + dep = ResolvedDependency(name="foo", version=Version("0.5.0")) + result: Dict[Dependency, List[VulnerabilityResult]] = { + dep: [ + VulnerabilityResult( + id="fake-id", description="this is not a real result", fix_versions=list() + ) + ] + } + fix_versions = list(resolve_fix_versions(vuln_service(), result)) + assert len(fix_versions) == 1 + assert fix_versions[0] == SkippedFixVersion( + dep=dep, + skip_reason="failed to fix dependency foo (0.5.0), unable to find fix version for " + "vulnerability fake-id", + ) + assert fix_versions[0].is_skipped()