Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pip, _fix: Implement --fix for PipSource #212

Merged
merged 17 commits into from
Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pip_audit/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
RequirementSource,
ResolveLibResolver,
)
from pip_audit._fix import resolve_fix_versions
from pip_audit._format import ColumnsFormat, CycloneDxFormat, JsonFormat, VulnerabilityFormat
from pip_audit._service import OsvService, PyPIService, VulnerabilityService
from pip_audit._service.interface import ResolvedDependency, SkippedDependency
Expand Down Expand Up @@ -234,6 +235,11 @@ def audit() -> None:
help="give more output; this setting overrides the `PIP_AUDIT_LOGLEVEL` variable and is "
"equivalent to setting it to `debug`",
)
parser.add_argument(
"--fix",
action="store_true",
help="automatically upgrade dependencies with known vulnerabilities",
)

args = parser.parse_args()
if args.verbose:
Expand Down Expand Up @@ -280,6 +286,12 @@ def audit() -> None:
pkg_count += 1
vuln_count += len(vulns)

# If the `--fix` flag has been applied, find a set of suitable fix versions and upgrade the
# dependencies at the source
if args.fix:
fix_versions = resolve_fix_versions(service, result)
source.fix_all(fix_versions)

# TODO(ww): Refine this: we should always output if our output format is an SBOM
# or other manifest format (like the default JSON format).
if vuln_count > 0:
Expand Down
14 changes: 14 additions & 0 deletions pip_audit/_dependency_source/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Iterator, List, Tuple

from packaging.requirements import Requirement
from packaging.version import Version

from pip_audit._service import Dependency

Expand All @@ -26,6 +27,19 @@ def collect(self) -> Iterator[Dependency]: # pragma: no cover
"""
raise NotImplementedError

def fix(self, dep: Dependency, fix_version: Version) -> None:
"""
Upgrade a dependency to the given fix version.
"""
raise NotImplementedError

def fix_all(self, fix_req: Iterator[Tuple[Dependency, Version]]) -> None:
"""
Upgrade a collection of dependencies to their associated fix versions.
"""
for (dep, fix_version) in fix_req:
self.fix(dep, fix_version)


class DependencySourceError(Exception):
"""
Expand Down
14 changes: 14 additions & 0 deletions pip_audit/_dependency_source/pip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"""

import logging
import subprocess
import sys
from pathlib import Path
from typing import Iterator, Sequence

Expand Down Expand Up @@ -87,6 +89,18 @@ def collect(self) -> Iterator[Dependency]:
except Exception as e:
raise PipSourceError("failed to list installed distributions") from e

def fix(self, dep: Dependency, fix_version: Version) -> None:
"""
Fixes a dependency version in this `PipSource`.
"""
fix_cmd = [sys.executable, "-m", "pip", "install", f"{dep.name}=={fix_version}"]
try:
subprocess.run(
fix_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
except subprocess.CalledProcessError as cpe:
raise RuntimeError from cpe
tetsuo-cpp marked this conversation as resolved.
Show resolved Hide resolved


class PipSourceError(DependencySourceError):
"""A `pip` specific `DependencySourceError`."""
Expand Down
52 changes: 52 additions & 0 deletions pip_audit/_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Resolving fix versions.
"""

from typing import Dict, Iterator, List, Tuple, cast

from packaging.version import Version

from pip_audit._service import (
Dependency,
ResolvedDependency,
VulnerabilityResult,
VulnerabilityService,
)


def resolve_fix_versions(
service: VulnerabilityService, result: Dict[Dependency, List[VulnerabilityResult]]
) -> Iterator[Tuple[ResolvedDependency, Version]]:
for (dep, vulns) in result.items():
if dep.is_skipped():
continue
if not vulns:
continue
dep = cast(ResolvedDependency, dep)
yield (dep, _resolve_fix_version(service, dep, vulns))


def _resolve_fix_version(
service: VulnerabilityService, dep: ResolvedDependency, vulns: List[VulnerabilityResult]
) -> Version:
# We need to upgrade to a fix version that satisfies all vulnerability results
#
# However, whenever we upgrade a dependency, we run the risk of introducing new vulnerabilities
# so we need to run this in a loop and continue polling the vulnerability service on each
# prospective resolved fix version
current_version = dep.version
current_vulns = vulns
while current_vulns:

def get_earliest_fix_version(fix_versions: List[Version]) -> Version:
for v in fix_versions:
if v > current_version:
return v
raise RuntimeError
tetsuo-cpp marked this conversation as resolved.
Show resolved Hide resolved

# We want to retrieve a version that potentially fixes all vulnerabilities
current_version = max(
[get_earliest_fix_version(v.fix_versions) for v in current_vulns if v.fix_versions]
)
_, current_vulns = service.query(ResolvedDependency(dep.name, current_version))
return current_version