Skip to content

Commit

Permalink
simplify marker simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
dimbleby committed Nov 27, 2022
1 parent 466090c commit 01eb695
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 152 deletions.
251 changes: 102 additions & 149 deletions src/poetry/core/version/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,19 +258,21 @@ def value(self) -> str:

def intersect(self, other: BaseMarker) -> BaseMarker:
if isinstance(other, SingleMarker):
return MultiMarker.of(self, other)
merged = _merge_single_markers(self, other, MultiMarker)
if merged is not None:
return merged

return MultiMarker(self, other)

return other.intersect(self)

def union(self, other: BaseMarker) -> BaseMarker:
if isinstance(other, SingleMarker):
if self == other:
return self
merged = _merge_single_markers(self, other, MarkerUnion)
if merged is not None:
return merged

if self == other.invert():
return AnyMarker()

return MarkerUnion.of(self, other)
return MarkerUnion(self, other)

return other.union(self)

Expand Down Expand Up @@ -337,7 +339,7 @@ def invert(self) -> BaseMarker:
max_ = self._constraint.max
max_operator = "<=" if self._constraint.include_max else "<"

return MultiMarker.of(
return MultiMarker(
SingleMarker(self._name, f"{min_operator} {min_}"),
SingleMarker(self._name, f"{max_operator} {max_}"),
).invert()
Expand Down Expand Up @@ -380,12 +382,11 @@ def _flatten_markers(

class MultiMarker(BaseMarker):
def __init__(self, *markers: BaseMarker) -> None:
self._markers = []

flattened_markers = _flatten_markers(markers, MultiMarker)
self._markers = _flatten_markers(markers, MultiMarker)

for m in flattened_markers:
self._markers.append(m)
@property
def markers(self) -> list[BaseMarker]:
return self._markers

@classmethod
def of(cls, *markers: BaseMarker) -> BaseMarker:
Expand All @@ -402,36 +403,22 @@ def of(cls, *markers: BaseMarker) -> BaseMarker:
if marker.is_any():
continue

if isinstance(marker, SingleMarker):
intersected = False
for i, mark in enumerate(new_markers):
if isinstance(mark, SingleMarker) and (
mark.name == marker.name
or {mark.name, marker.name} == PYTHON_VERSION_MARKERS
):
new_marker = _merge_single_markers(mark, marker, cls)
if new_marker is not None:
new_markers[i] = new_marker
intersected = True

elif isinstance(mark, MarkerUnion):
intersection = mark.intersect(marker)
if isinstance(intersection, SingleMarker):
new_markers[i] = intersection
elif intersection.is_empty():
return EmptyMarker()
if intersected:
continue

elif isinstance(marker, MarkerUnion):
for mark in new_markers:
if isinstance(mark, SingleMarker):
intersection = marker.intersect(mark)
if isinstance(intersection, SingleMarker):
marker = intersection
break
elif intersection.is_empty():
return EmptyMarker()
intersected = False
for i, mark in enumerate(new_markers):
# If we have a SingleMarker then with any luck after intersection
# it'll become another SingleMarker.
if isinstance(mark, SingleMarker):
new_marker = marker.intersect(mark)
if new_marker.is_empty():
return EmptyMarker()

if isinstance(new_marker, SingleMarker):
new_markers[i] = new_marker
intersected = True
break

if intersected:
continue

new_markers.append(marker)

Expand All @@ -443,34 +430,20 @@ def of(cls, *markers: BaseMarker) -> BaseMarker:

return MultiMarker(*new_markers)

@property
def markers(self) -> list[BaseMarker]:
return self._markers

def intersect(self, other: BaseMarker) -> BaseMarker:
if other.is_any():
return self

if other.is_empty():
return other
multi = MultiMarker(self, other)
return dnf(multi)

def union(self, other: BaseMarker) -> BaseMarker:
if isinstance(other, MarkerUnion):
return other.intersect(self)

new_markers = self._markers + [other]

multi = MultiMarker.of(*new_markers)
return other.union(self)

if isinstance(multi, MultiMarker):
return dnf(multi)

return multi

def union(self, other: BaseMarker) -> BaseMarker:
if isinstance(other, (SingleMarker, MultiMarker)):
return MarkerUnion.of(self, other)
union = MarkerUnion(self, other)
conjunction = cnf(union)
if not isinstance(conjunction, MultiMarker):
return conjunction

return other.union(self)
return dnf(conjunction)

def union_simplify(self, other: BaseMarker) -> BaseMarker | None:
"""
Expand All @@ -486,20 +459,20 @@ def union_simplify(self, other: BaseMarker) -> BaseMarker | None:
new_markers = []
for marker in self._markers:
union = marker.union(other)
if not union.is_any():
new_markers.append(union)
if union.is_any():
return AnyMarker()

new_markers.append(union)

if len(new_markers) == 1:
return new_markers[0]

if other in new_markers and all(
other == m or isinstance(m, MarkerUnion) and other in m.markers
for m in new_markers
):
return other

if not any(isinstance(m, MarkerUnion) for m in new_markers):
return self.of(*new_markers)

elif isinstance(other, MultiMarker):
common_markers = [
marker for marker in self.markers if marker in other.markers
Expand All @@ -518,28 +491,11 @@ def union_simplify(self, other: BaseMarker) -> BaseMarker | None:
return other

if common_markers:
unique_union = self.of(*unique_markers).union(
self.of(*other_unique_markers)
unique_union = MultiMarker(*unique_markers).union(
MultiMarker(*other_unique_markers)
)
if not isinstance(unique_union, MarkerUnion):
return self.of(*common_markers).intersect(unique_union)

else:
# Usually this operation just complicates things, but the special case
# where it doesn't allows the collapse of adjacent ranges eg
#
# 'python_version >= "3.6" and python_version < "3.6.2"' union
# 'python_version >= "3.6.2" and python_version < "3.7"' ->
#
# 'python_version >= "3.6" and python_version < "3.7"'.
unions = [
m1.union(m2) for m2 in other_unique_markers for m1 in unique_markers
]
conjunction = self.of(*unions)
if not isinstance(conjunction, MultiMarker) or not any(
isinstance(m, MarkerUnion) for m in conjunction.markers
):
return conjunction
return MultiMarker(*common_markers).intersect(unique_union)

return None

Expand Down Expand Up @@ -582,7 +538,7 @@ def only(self, *marker_names: str) -> BaseMarker:
def invert(self) -> BaseMarker:
markers = [marker.invert() for marker in self._markers]

return MarkerUnion.of(*markers)
return MarkerUnion(*markers)

def __eq__(self, other: object) -> bool:
if not isinstance(other, MultiMarker):
Expand Down Expand Up @@ -610,7 +566,7 @@ def __str__(self) -> str:

class MarkerUnion(BaseMarker):
def __init__(self, *markers: BaseMarker) -> None:
self._markers = list(markers)
self._markers = _flatten_markers(markers, MarkerUnion)

@property
def markers(self) -> list[BaseMarker]:
Expand All @@ -625,34 +581,30 @@ def of(cls, *markers: BaseMarker) -> BaseMarker:
old_markers = new_markers
new_markers = []
for marker in old_markers:
if marker in new_markers or marker.is_empty():
if marker in new_markers:
continue

if marker.is_empty():
continue

included = False
for i, mark in enumerate(new_markers):
# If we have a SingleMarker then with any luck after union it'll
# become another SingleMarker.
if isinstance(mark, SingleMarker):
new_marker = marker.union(mark)
if new_marker.is_any():
return AnyMarker()

if isinstance(new_marker, SingleMarker):
new_markers[i] = new_marker
included = True
break

if isinstance(marker, SingleMarker):
for i, mark in enumerate(new_markers):
if isinstance(mark, SingleMarker) and (
mark.name == marker.name
or {mark.name, marker.name} == PYTHON_VERSION_MARKERS
):
new_marker = _merge_single_markers(mark, marker, cls)
if new_marker is not None:
new_markers[i] = new_marker
included = True
break

elif isinstance(mark, MultiMarker):
union = mark.union_simplify(marker)
if union is not None:
new_markers[i] = union
included = True
break

elif isinstance(marker, MultiMarker):
included = False
for i, mark in enumerate(new_markers):
union = marker.union_simplify(mark)
# If we have a MultiMarker then we can look for the simplifications
# implemented in union_simplify().
elif isinstance(mark, MultiMarker):
union = mark.union_simplify(marker)
if union is not None:
new_markers[i] = union
included = True
Expand All @@ -661,8 +613,9 @@ def of(cls, *markers: BaseMarker) -> BaseMarker:
if included:
# flatten again because union_simplify may return a union
new_markers = _flatten_markers(new_markers, MarkerUnion)
else:
new_markers.append(marker)
continue

new_markers.append(marker)

if any(m.is_any() for m in new_markers):
return AnyMarker()
Expand All @@ -682,39 +635,20 @@ def append(self, marker: BaseMarker) -> None:
self._markers.append(marker)

def intersect(self, other: BaseMarker) -> BaseMarker:
if other.is_any():
return self

if other.is_empty():
return other

new_markers = []
if isinstance(other, (SingleMarker, MultiMarker)):
for marker in self._markers:
intersection = marker.intersect(other)

if not intersection.is_empty():
new_markers.append(intersection)
elif isinstance(other, MarkerUnion):
for our_marker in self._markers:
for their_marker in other.markers:
intersection = our_marker.intersect(their_marker)

if not intersection.is_empty():
new_markers.append(intersection)
if isinstance(other, MultiMarker):
return other.intersect(self)

return MarkerUnion.of(*new_markers)
multi = MultiMarker(self, other)
return dnf(multi)

def union(self, other: BaseMarker) -> BaseMarker:
if other.is_any():
return other
union = MarkerUnion(self, other)

if other.is_empty():
return self
conjunction = cnf(union)
if not isinstance(conjunction, MultiMarker):
return conjunction

new_markers = self._markers + [other]

return MarkerUnion.of(*new_markers)
return dnf(conjunction)

def validate(self, environment: dict[str, Any] | None) -> bool:
return any(m.validate(environment) for m in self._markers)
Expand Down Expand Up @@ -756,8 +690,7 @@ def only(self, *marker_names: str) -> BaseMarker:

def invert(self) -> BaseMarker:
markers = [marker.invert() for marker in self._markers]

return MultiMarker.of(*markers)
return MultiMarker(*markers)

def __eq__(self, other: object) -> bool:
if not isinstance(other, MarkerUnion):
Expand Down Expand Up @@ -845,6 +778,23 @@ def _compact_markers(tree_elements: Tree, tree_prefix: str = "") -> BaseMarker:
return MarkerUnion.of(*groups)


def cnf(marker: BaseMarker) -> BaseMarker:
"""Transforms the marker into CNF (conjunctive normal form)."""
if isinstance(marker, MarkerUnion):
cnf_markers = [cnf(m) for m in marker.markers]
sub_marker_lists = [
m.markers if isinstance(m, MultiMarker) else [m] for m in cnf_markers
]
return MultiMarker.of(
*[MarkerUnion.of(*c) for c in itertools.product(*sub_marker_lists)]
)

if isinstance(marker, MultiMarker):
return MultiMarker.of(*[cnf(m) for m in marker.markers])

return marker


def dnf(marker: BaseMarker) -> BaseMarker:
"""Transforms the marker into DNF (disjunctive normal form)."""
if isinstance(marker, MultiMarker):
Expand All @@ -868,6 +818,9 @@ def _merge_single_markers(
if {marker1.name, marker2.name} == PYTHON_VERSION_MARKERS:
return _merge_python_version_single_markers(marker1, marker2, merge_class)

if marker1.name != marker2.name:
return None

if merge_class == MultiMarker:
merge_method = marker1.constraint.intersect
else:
Expand Down
Loading

0 comments on commit 01eb695

Please sign in to comment.