From 01eb695967aa78d220b1f3cb91f8932eb655d18e Mon Sep 17 00:00:00 2001 From: David Hotham Date: Sun, 27 Nov 2022 15:31:36 +0000 Subject: [PATCH] simplify marker simplification --- src/poetry/core/version/markers.py | 251 ++++++++++++----------------- tests/version/test_markers.py | 6 +- 2 files changed, 105 insertions(+), 152 deletions(-) diff --git a/src/poetry/core/version/markers.py b/src/poetry/core/version/markers.py index 8915d53c9..304889a13 100644 --- a/src/poetry/core/version/markers.py +++ b/src/poetry/core/version/markers.py @@ -258,19 +258,21 @@ def value(self) -> str: def intersect(self, other: BaseMarker) -> BaseMarker: if isinstance(other, SingleMarker): - return MultiMarker.of(self, other) + merged = _merge_single_markers(self, other, MultiMarker) + if merged is not None: + return merged + + return MultiMarker(self, other) return other.intersect(self) def union(self, other: BaseMarker) -> BaseMarker: if isinstance(other, SingleMarker): - if self == other: - return self + merged = _merge_single_markers(self, other, MarkerUnion) + if merged is not None: + return merged - if self == other.invert(): - return AnyMarker() - - return MarkerUnion.of(self, other) + return MarkerUnion(self, other) return other.union(self) @@ -337,7 +339,7 @@ def invert(self) -> BaseMarker: max_ = self._constraint.max max_operator = "<=" if self._constraint.include_max else "<" - return MultiMarker.of( + return MultiMarker( SingleMarker(self._name, f"{min_operator} {min_}"), SingleMarker(self._name, f"{max_operator} {max_}"), ).invert() @@ -380,12 +382,11 @@ def _flatten_markers( class MultiMarker(BaseMarker): def __init__(self, *markers: BaseMarker) -> None: - self._markers = [] - - flattened_markers = _flatten_markers(markers, MultiMarker) + self._markers = _flatten_markers(markers, MultiMarker) - for m in flattened_markers: - self._markers.append(m) + @property + def markers(self) -> list[BaseMarker]: + return self._markers @classmethod def of(cls, *markers: BaseMarker) -> BaseMarker: @@ -402,36 +403,22 @@ def of(cls, *markers: BaseMarker) -> BaseMarker: if marker.is_any(): continue - if isinstance(marker, SingleMarker): - intersected = False - for i, mark in enumerate(new_markers): - if isinstance(mark, SingleMarker) and ( - mark.name == marker.name - or {mark.name, marker.name} == PYTHON_VERSION_MARKERS - ): - new_marker = _merge_single_markers(mark, marker, cls) - if new_marker is not None: - new_markers[i] = new_marker - intersected = True - - elif isinstance(mark, MarkerUnion): - intersection = mark.intersect(marker) - if isinstance(intersection, SingleMarker): - new_markers[i] = intersection - elif intersection.is_empty(): - return EmptyMarker() - if intersected: - continue - - elif isinstance(marker, MarkerUnion): - for mark in new_markers: - if isinstance(mark, SingleMarker): - intersection = marker.intersect(mark) - if isinstance(intersection, SingleMarker): - marker = intersection - break - elif intersection.is_empty(): - return EmptyMarker() + intersected = False + for i, mark in enumerate(new_markers): + # If we have a SingleMarker then with any luck after intersection + # it'll become another SingleMarker. + if isinstance(mark, SingleMarker): + new_marker = marker.intersect(mark) + if new_marker.is_empty(): + return EmptyMarker() + + if isinstance(new_marker, SingleMarker): + new_markers[i] = new_marker + intersected = True + break + + if intersected: + continue new_markers.append(marker) @@ -443,34 +430,20 @@ def of(cls, *markers: BaseMarker) -> BaseMarker: return MultiMarker(*new_markers) - @property - def markers(self) -> list[BaseMarker]: - return self._markers - def intersect(self, other: BaseMarker) -> BaseMarker: - if other.is_any(): - return self - - if other.is_empty(): - return other + multi = MultiMarker(self, other) + return dnf(multi) + def union(self, other: BaseMarker) -> BaseMarker: if isinstance(other, MarkerUnion): - return other.intersect(self) - - new_markers = self._markers + [other] - - multi = MultiMarker.of(*new_markers) + return other.union(self) - if isinstance(multi, MultiMarker): - return dnf(multi) - - return multi - - def union(self, other: BaseMarker) -> BaseMarker: - if isinstance(other, (SingleMarker, MultiMarker)): - return MarkerUnion.of(self, other) + union = MarkerUnion(self, other) + conjunction = cnf(union) + if not isinstance(conjunction, MultiMarker): + return conjunction - return other.union(self) + return dnf(conjunction) def union_simplify(self, other: BaseMarker) -> BaseMarker | None: """ @@ -486,20 +459,20 @@ def union_simplify(self, other: BaseMarker) -> BaseMarker | None: new_markers = [] for marker in self._markers: union = marker.union(other) - if not union.is_any(): - new_markers.append(union) + if union.is_any(): + return AnyMarker() + + new_markers.append(union) if len(new_markers) == 1: return new_markers[0] + if other in new_markers and all( other == m or isinstance(m, MarkerUnion) and other in m.markers for m in new_markers ): return other - if not any(isinstance(m, MarkerUnion) for m in new_markers): - return self.of(*new_markers) - elif isinstance(other, MultiMarker): common_markers = [ marker for marker in self.markers if marker in other.markers @@ -518,28 +491,11 @@ def union_simplify(self, other: BaseMarker) -> BaseMarker | None: return other if common_markers: - unique_union = self.of(*unique_markers).union( - self.of(*other_unique_markers) + unique_union = MultiMarker(*unique_markers).union( + MultiMarker(*other_unique_markers) ) if not isinstance(unique_union, MarkerUnion): - return self.of(*common_markers).intersect(unique_union) - - else: - # Usually this operation just complicates things, but the special case - # where it doesn't allows the collapse of adjacent ranges eg - # - # 'python_version >= "3.6" and python_version < "3.6.2"' union - # 'python_version >= "3.6.2" and python_version < "3.7"' -> - # - # 'python_version >= "3.6" and python_version < "3.7"'. - unions = [ - m1.union(m2) for m2 in other_unique_markers for m1 in unique_markers - ] - conjunction = self.of(*unions) - if not isinstance(conjunction, MultiMarker) or not any( - isinstance(m, MarkerUnion) for m in conjunction.markers - ): - return conjunction + return MultiMarker(*common_markers).intersect(unique_union) return None @@ -582,7 +538,7 @@ def only(self, *marker_names: str) -> BaseMarker: def invert(self) -> BaseMarker: markers = [marker.invert() for marker in self._markers] - return MarkerUnion.of(*markers) + return MarkerUnion(*markers) def __eq__(self, other: object) -> bool: if not isinstance(other, MultiMarker): @@ -610,7 +566,7 @@ def __str__(self) -> str: class MarkerUnion(BaseMarker): def __init__(self, *markers: BaseMarker) -> None: - self._markers = list(markers) + self._markers = _flatten_markers(markers, MarkerUnion) @property def markers(self) -> list[BaseMarker]: @@ -625,34 +581,30 @@ def of(cls, *markers: BaseMarker) -> BaseMarker: old_markers = new_markers new_markers = [] for marker in old_markers: - if marker in new_markers or marker.is_empty(): + if marker in new_markers: + continue + + if marker.is_empty(): continue included = False + for i, mark in enumerate(new_markers): + # If we have a SingleMarker then with any luck after union it'll + # become another SingleMarker. + if isinstance(mark, SingleMarker): + new_marker = marker.union(mark) + if new_marker.is_any(): + return AnyMarker() + + if isinstance(new_marker, SingleMarker): + new_markers[i] = new_marker + included = True + break - if isinstance(marker, SingleMarker): - for i, mark in enumerate(new_markers): - if isinstance(mark, SingleMarker) and ( - mark.name == marker.name - or {mark.name, marker.name} == PYTHON_VERSION_MARKERS - ): - new_marker = _merge_single_markers(mark, marker, cls) - if new_marker is not None: - new_markers[i] = new_marker - included = True - break - - elif isinstance(mark, MultiMarker): - union = mark.union_simplify(marker) - if union is not None: - new_markers[i] = union - included = True - break - - elif isinstance(marker, MultiMarker): - included = False - for i, mark in enumerate(new_markers): - union = marker.union_simplify(mark) + # If we have a MultiMarker then we can look for the simplifications + # implemented in union_simplify(). + elif isinstance(mark, MultiMarker): + union = mark.union_simplify(marker) if union is not None: new_markers[i] = union included = True @@ -661,8 +613,9 @@ def of(cls, *markers: BaseMarker) -> BaseMarker: if included: # flatten again because union_simplify may return a union new_markers = _flatten_markers(new_markers, MarkerUnion) - else: - new_markers.append(marker) + continue + + new_markers.append(marker) if any(m.is_any() for m in new_markers): return AnyMarker() @@ -682,39 +635,20 @@ def append(self, marker: BaseMarker) -> None: self._markers.append(marker) def intersect(self, other: BaseMarker) -> BaseMarker: - if other.is_any(): - return self - - if other.is_empty(): - return other - - new_markers = [] - if isinstance(other, (SingleMarker, MultiMarker)): - for marker in self._markers: - intersection = marker.intersect(other) - - if not intersection.is_empty(): - new_markers.append(intersection) - elif isinstance(other, MarkerUnion): - for our_marker in self._markers: - for their_marker in other.markers: - intersection = our_marker.intersect(their_marker) - - if not intersection.is_empty(): - new_markers.append(intersection) + if isinstance(other, MultiMarker): + return other.intersect(self) - return MarkerUnion.of(*new_markers) + multi = MultiMarker(self, other) + return dnf(multi) def union(self, other: BaseMarker) -> BaseMarker: - if other.is_any(): - return other + union = MarkerUnion(self, other) - if other.is_empty(): - return self + conjunction = cnf(union) + if not isinstance(conjunction, MultiMarker): + return conjunction - new_markers = self._markers + [other] - - return MarkerUnion.of(*new_markers) + return dnf(conjunction) def validate(self, environment: dict[str, Any] | None) -> bool: return any(m.validate(environment) for m in self._markers) @@ -756,8 +690,7 @@ def only(self, *marker_names: str) -> BaseMarker: def invert(self) -> BaseMarker: markers = [marker.invert() for marker in self._markers] - - return MultiMarker.of(*markers) + return MultiMarker(*markers) def __eq__(self, other: object) -> bool: if not isinstance(other, MarkerUnion): @@ -845,6 +778,23 @@ def _compact_markers(tree_elements: Tree, tree_prefix: str = "") -> BaseMarker: return MarkerUnion.of(*groups) +def cnf(marker: BaseMarker) -> BaseMarker: + """Transforms the marker into CNF (conjunctive normal form).""" + if isinstance(marker, MarkerUnion): + cnf_markers = [cnf(m) for m in marker.markers] + sub_marker_lists = [ + m.markers if isinstance(m, MultiMarker) else [m] for m in cnf_markers + ] + return MultiMarker.of( + *[MarkerUnion.of(*c) for c in itertools.product(*sub_marker_lists)] + ) + + if isinstance(marker, MultiMarker): + return MultiMarker.of(*[cnf(m) for m in marker.markers]) + + return marker + + def dnf(marker: BaseMarker) -> BaseMarker: """Transforms the marker into DNF (disjunctive normal form).""" if isinstance(marker, MultiMarker): @@ -868,6 +818,9 @@ def _merge_single_markers( if {marker1.name, marker2.name} == PYTHON_VERSION_MARKERS: return _merge_python_version_single_markers(marker1, marker2, merge_class) + if marker1.name != marker2.name: + return None + if merge_class == MultiMarker: merge_method = marker1.constraint.intersect else: diff --git a/tests/version/test_markers.py b/tests/version/test_markers.py index 345dca1e6..c03e5ff78 100644 --- a/tests/version/test_markers.py +++ b/tests/version/test_markers.py @@ -511,16 +511,16 @@ def test_multi_marker_union_multi_is_multi( 'python_full_version >= "3.6.2" and python_version < "3.7"', 'python_version >= "3.6" and python_version < "3.7"', ), - # Ranges with same end. Ideally the union would give the lower version first. + # Ranges with same end. ( 'python_version >= "3.6" and python_version < "3.7"', 'python_full_version >= "3.6.2" and python_version < "3.7"', - 'python_version < "3.7" and python_version >= "3.6"', + 'python_version >= "3.6" and python_version < "3.7"', ), ( 'python_version >= "3.6" and python_version <= "3.7"', 'python_full_version >= "3.6.2" and python_version <= "3.7"', - 'python_version <= "3.7" and python_version >= "3.6"', + 'python_version >= "3.6" and python_version <= "3.7"', ), # A range covers an exact marker. (