|
3 | 3 | import collections |
4 | 4 | import itertools |
5 | 5 | import operator |
6 | | -from typing import TYPE_CHECKING, Collection, Generic, Iterable, Mapping |
| 6 | +from typing import TYPE_CHECKING, Generic |
7 | 7 |
|
8 | 8 | from ..structs import ( |
9 | 9 | CT, |
|
27 | 27 | ) |
28 | 28 |
|
29 | 29 | if TYPE_CHECKING: |
| 30 | + from collections.abc import Collection, Iterable, Mapping |
| 31 | + |
30 | 32 | from ..providers import AbstractProvider, Preference |
31 | 33 | from ..reporters import BaseReporter |
32 | 34 |
|
| 35 | +_OPTIMISTIC_BACKJUMPING_RATIO: float = 0.1 |
| 36 | + |
33 | 37 |
|
34 | 38 | def _build_result(state: State[RT, CT, KT]) -> Result[RT, CT, KT]: |
35 | 39 | mapping = state.mapping |
@@ -77,6 +81,11 @@ def __init__( |
77 | 81 | self._r = reporter |
78 | 82 | self._states: list[State[RT, CT, KT]] = [] |
79 | 83 |
|
| 84 | + # Optimistic backjumping variables |
| 85 | + self._optimistic_backjumping_ratio = _OPTIMISTIC_BACKJUMPING_RATIO |
| 86 | + self._save_states: list[State[RT, CT, KT]] | None = None |
| 87 | + self._optimistic_start_round: int | None = None |
| 88 | + |
80 | 89 | @property |
81 | 90 | def state(self) -> State[RT, CT, KT]: |
82 | 91 | try: |
@@ -274,6 +283,25 @@ def _patch_criteria( |
274 | 283 | ) |
275 | 284 | return True |
276 | 285 |
|
| 286 | + def _save_state(self) -> None: |
| 287 | + """Save states for potential rollback if optimistic backjumping fails.""" |
| 288 | + if self._save_states is None: |
| 289 | + self._save_states = [ |
| 290 | + State( |
| 291 | + mapping=s.mapping.copy(), |
| 292 | + criteria=s.criteria.copy(), |
| 293 | + backtrack_causes=s.backtrack_causes[:], |
| 294 | + ) |
| 295 | + for s in self._states |
| 296 | + ] |
| 297 | + |
| 298 | + def _rollback_states(self) -> None: |
| 299 | + """Rollback states and disable optimistic backjumping.""" |
| 300 | + self._optimistic_backjumping_ratio = 0.0 |
| 301 | + if self._save_states: |
| 302 | + self._states = self._save_states |
| 303 | + self._save_states = None |
| 304 | + |
277 | 305 | def _backjump(self, causes: list[RequirementInformation[RT, CT]]) -> bool: |
278 | 306 | """Perform backjumping. |
279 | 307 |
|
@@ -324,13 +352,26 @@ def _backjump(self, causes: list[RequirementInformation[RT, CT]]) -> bool: |
324 | 352 | except (IndexError, KeyError): |
325 | 353 | raise ResolutionImpossible(causes) from None |
326 | 354 |
|
327 | | - # Only backjump if the current broken state is |
328 | | - # an incompatible dependency |
329 | | - if name not in incompatible_deps: |
| 355 | + if ( |
| 356 | + not self._optimistic_backjumping_ratio |
| 357 | + and name not in incompatible_deps |
| 358 | + ): |
| 359 | + # For safe backjumping only backjump if the current dependency |
| 360 | + # is not the same as the incompatible dependency |
330 | 361 | break |
331 | 362 |
|
| 363 | + # On the first time a non-safe backjump is done the state |
| 364 | + # is saved so we can restore it later if the resolution fails |
| 365 | + if ( |
| 366 | + self._optimistic_backjumping_ratio |
| 367 | + and self._save_states is None |
| 368 | + and name not in incompatible_deps |
| 369 | + ): |
| 370 | + self._save_state() |
| 371 | + |
332 | 372 | # If the current dependencies and the incompatible dependencies |
333 | | - # are overlapping then we have found a cause of the incompatibility |
| 373 | + # are overlapping then we have likely found a cause of the |
| 374 | + # incompatibility |
334 | 375 | current_dependencies = { |
335 | 376 | self._p.identify(d) for d in self._p.get_dependencies(candidate) |
336 | 377 | } |
@@ -394,9 +435,32 @@ def resolve(self, requirements: Iterable[RT], max_rounds: int) -> State[RT, CT, |
394 | 435 | # pinning the virtual "root" package in the graph. |
395 | 436 | self._push_new_state() |
396 | 437 |
|
| 438 | + # Variables for optimistic backjumping |
| 439 | + optimistic_rounds_cutoff: int | None = None |
| 440 | + optimistic_backjumping_start_round: int | None = None |
| 441 | + |
397 | 442 | for round_index in range(max_rounds): |
398 | 443 | self._r.starting_round(index=round_index) |
399 | 444 |
|
| 445 | + # Handle if optimistic backjumping has been running for too long |
| 446 | + if self._optimistic_backjumping_ratio and self._save_states is not None: |
| 447 | + if optimistic_backjumping_start_round is None: |
| 448 | + optimistic_backjumping_start_round = round_index |
| 449 | + optimistic_rounds_cutoff = int( |
| 450 | + (max_rounds - round_index) * self._optimistic_backjumping_ratio |
| 451 | + ) |
| 452 | + |
| 453 | + if optimistic_rounds_cutoff <= 0: |
| 454 | + self._rollback_states() |
| 455 | + continue |
| 456 | + elif optimistic_rounds_cutoff is not None: |
| 457 | + if ( |
| 458 | + round_index - optimistic_backjumping_start_round |
| 459 | + >= optimistic_rounds_cutoff |
| 460 | + ): |
| 461 | + self._rollback_states() |
| 462 | + continue |
| 463 | + |
400 | 464 | unsatisfied_names = [ |
401 | 465 | key |
402 | 466 | for key, criterion in self.state.criteria.items() |
@@ -448,12 +512,29 @@ def resolve(self, requirements: Iterable[RT], max_rounds: int) -> State[RT, CT, |
448 | 512 | # Backjump if pinning fails. The backjump process puts us in |
449 | 513 | # an unpinned state, so we can work on it in the next round. |
450 | 514 | self._r.resolving_conflicts(causes=causes) |
451 | | - success = self._backjump(causes) |
452 | | - self.state.backtrack_causes[:] = causes |
453 | 515 |
|
454 | | - # Dead ends everywhere. Give up. |
455 | | - if not success: |
456 | | - raise ResolutionImpossible(self.state.backtrack_causes) |
| 516 | + try: |
| 517 | + success = self._backjump(causes) |
| 518 | + except ResolutionImpossible: |
| 519 | + if self._optimistic_backjumping_ratio and self._save_states: |
| 520 | + failed_optimistic_backjumping = True |
| 521 | + else: |
| 522 | + raise |
| 523 | + else: |
| 524 | + failed_optimistic_backjumping = bool( |
| 525 | + not success |
| 526 | + and self._optimistic_backjumping_ratio |
| 527 | + and self._save_states |
| 528 | + ) |
| 529 | + |
| 530 | + if failed_optimistic_backjumping and self._save_states: |
| 531 | + self._rollback_states() |
| 532 | + else: |
| 533 | + self.state.backtrack_causes[:] = causes |
| 534 | + |
| 535 | + # Dead ends everywhere. Give up. |
| 536 | + if not success: |
| 537 | + raise ResolutionImpossible(self.state.backtrack_causes) |
457 | 538 | else: |
458 | 539 | # discard as information sources any invalidated names |
459 | 540 | # (unsatisfied names that were previously satisfied) |
|
0 commit comments