Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow to consume the initialized solution #915

Merged
merged 13 commits into from
Jun 27, 2024
1 change: 1 addition & 0 deletions .github/workflows/downstream_python_enterprise.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
uses: actions/checkout@v4
with:
path: './timefold-solver'
ref: ${{ github.event.pull_request.head.sha }} # The GHA event will pull the main branch by default, and we must specify the PR reference version

# Need to check for stale repo, since Github is not aware of the build chain and therefore doesn't automate it.
- name: Checkout timefold-solver-python (PR) # Checkout the PR branch first, if it exists
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ default SolverJobBuilder<Solution_, ProblemId_> withProblem(Solution_ problem) {
SolverJobBuilder<Solution_, ProblemId_>
withFinalBestSolutionConsumer(Consumer<? super Solution_> finalBestSolutionConsumer);

/**
* Sets the consumer of the first initialized solution. First initialized solution is the solution at the end of
* the last phase that immediately precedes the first local search phase. This solution marks the beginning of actual
* optimization process.
*
* @param firstInitializedSolutionConsumer never null, called only once before starting the first Local Search phase
* @return this, never null
*/
SolverJobBuilder<Solution_, ProblemId_>
withFirstInitializedSolutionConsumer(Consumer<? super Solution_> firstInitializedSolutionConsumer);

/**
* Sets the custom exception handler.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,10 @@ public static class Builder<Solution_> extends AbstractPhase.Builder<Solution_>
private final EntityPlacer<Solution_> entityPlacer;
private final ConstructionHeuristicDecider<Solution_> decider;

public Builder(int phaseIndex, String logIndentation, Termination<Solution_> phaseTermination,
public Builder(int phaseIndex, boolean triggerFirstInitializedSolutionEvent, String logIndentation,
Termination<Solution_> phaseTermination,
EntityPlacer<Solution_> entityPlacer, ConstructionHeuristicDecider<Solution_> decider) {
super(phaseIndex, logIndentation, phaseTermination);
super(phaseIndex, triggerFirstInitializedSolutionEvent, logIndentation, phaseTermination);
this.entityPlacer = entityPlacer;
this.decider = decider;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ public DefaultConstructionHeuristicPhaseFactory(ConstructionHeuristicPhaseConfig
}

@Override
public ConstructionHeuristicPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public ConstructionHeuristicPhase<Solution_> buildPhase(int phaseIndex, boolean triggerFirstInitializedSolutionEvent,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
ConstructionHeuristicType constructionHeuristicType_ = Objects.requireNonNullElse(
phaseConfig.getConstructionHeuristicType(),
ConstructionHeuristicType.ALLOCATE_ENTITY_FROM_QUEUE);
Expand All @@ -71,6 +72,7 @@ public ConstructionHeuristicPhase<Solution_> buildPhase(int phaseIndex, Heuristi

DefaultConstructionHeuristicPhase.Builder<Solution_> builder = new DefaultConstructionHeuristicPhase.Builder<>(
phaseIndex,
triggerFirstInitializedSolutionEvent,
solverConfigPolicy.getLogIndentation(),
phaseTermination,
entityPlacer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ public DefaultExhaustiveSearchPhaseFactory(ExhaustiveSearchPhaseConfig phaseConf
}

@Override
public ExhaustiveSearchPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public ExhaustiveSearchPhase<Solution_> buildPhase(int phaseIndex, boolean triggerFirstInitializedSolutionEvent,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
ExhaustiveSearchType exhaustiveSearchType_ = Objects.requireNonNullElse(
phaseConfig.getExhaustiveSearchType(),
ExhaustiveSearchType.BRANCH_AND_BOUND);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ public DefaultLocalSearchPhaseFactory(LocalSearchPhaseConfig phaseConfig) {
}

@Override
public LocalSearchPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public LocalSearchPhase<Solution_> buildPhase(int phaseIndex, boolean triggerFirstInitializedSolutionEvent,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
HeuristicConfigPolicy<Solution_> phaseConfigPolicy = solverConfigPolicy.createPhaseConfigPolicy();
Termination<Solution_> phaseTermination = buildPhaseTermination(phaseConfigPolicy, solverTermination);
DefaultLocalSearchPhase.Builder<Solution_> builder =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ public DefaultPartitionedSearchPhaseFactory(PartitionedSearchPhaseConfig phaseCo
}

@Override
public PartitionedSearchPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public PartitionedSearchPhase<Solution_> buildPhase(int phaseIndex, boolean triggerFirstInitializedSolutionEvent,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
return TimefoldSolverEnterpriseService.loadOrFail(TimefoldSolverEnterpriseService.Feature.PARTITIONED_SEARCH)
.buildPartitionedSearch(phaseIndex, phaseConfig, solverConfigPolicy, solverTermination,
this::buildPhaseTermination);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public abstract class AbstractPhase<Solution_> implements Phase<Solution_> {
protected final boolean assertStepScoreFromScratch;
protected final boolean assertExpectedStepScore;
protected final boolean assertShadowVariablesAreNotStaleAfterStep;
protected final boolean triggerFirstInitializedSolutionEvent;

/** Used for {@link #addPhaseLifecycleListener(PhaseLifecycleListener)}. */
protected PhaseLifecycleSupport<Solution_> phaseLifecycleSupport = new PhaseLifecycleSupport<>();
Expand All @@ -45,6 +46,7 @@ protected AbstractPhase(Builder<Solution_> builder) {
assertStepScoreFromScratch = builder.assertStepScoreFromScratch;
assertExpectedStepScore = builder.assertExpectedStepScore;
assertShadowVariablesAreNotStaleAfterStep = builder.assertShadowVariablesAreNotStaleAfterStep;
triggerFirstInitializedSolutionEvent = builder.triggerFirstInitializedSolutionEvent;
}

public int getPhaseIndex() {
Expand Down Expand Up @@ -77,6 +79,11 @@ public boolean isAssertShadowVariablesAreNotStaleAfterStep() {

public abstract String getPhaseTypeString();

@Override
public boolean triggersFirstInitializedSolutionEvent() {
return triggerFirstInitializedSolutionEvent;
}

// ************************************************************************
// Lifecycle methods
// ************************************************************************
Expand Down Expand Up @@ -209,6 +216,7 @@ but planning list variable (%s) has (%d) unexpected unassigned values.
protected abstract static class Builder<Solution_> {

private final int phaseIndex;
private final boolean triggerFirstInitializedSolutionEvent;
private final String logIndentation;
private final Termination<Solution_> phaseTermination;

Expand All @@ -217,7 +225,13 @@ protected abstract static class Builder<Solution_> {
private boolean assertShadowVariablesAreNotStaleAfterStep = false;

protected Builder(int phaseIndex, String logIndentation, Termination<Solution_> phaseTermination) {
this(phaseIndex, false, logIndentation, phaseTermination);
}

protected Builder(int phaseIndex, boolean triggerFirstInitializedSolutionEvent, String logIndentation,
Termination<Solution_> phaseTermination) {
this.phaseIndex = phaseIndex;
this.triggerFirstInitializedSolutionEvent = triggerFirstInitializedSolutionEvent;
this.logIndentation = logIndentation;
this.phaseTermination = phaseTermination;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ public NoChangePhaseFactory(NoChangePhaseConfig phaseConfig) {
}

@Override
public NoChangePhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public NoChangePhase<Solution_> buildPhase(int phaseIndex, boolean triggerFirstInitializedSolutionEvent,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
HeuristicConfigPolicy<Solution_> phaseConfigPolicy = solverConfigPolicy.createPhaseConfigPolicy();
return new NoChangePhase.Builder<>(phaseIndex, solverConfigPolicy.getLogIndentation(),
buildPhaseTermination(phaseConfigPolicy, solverTermination)).build();
Expand Down
11 changes: 11 additions & 0 deletions core/src/main/java/ai/timefold/solver/core/impl/phase/Phase.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ai.timefold.solver.core.impl.phase;

import java.util.function.Consumer;

import ai.timefold.solver.core.api.domain.solution.PlanningSolution;
import ai.timefold.solver.core.api.solver.Solver;
import ai.timefold.solver.core.impl.phase.event.PhaseLifecycleListener;
Expand Down Expand Up @@ -36,4 +38,13 @@ public interface Phase<Solution_> extends PhaseLifecycleListener<Solution_> {

void solve(SolverScope<Solution_> solverScope);

/**
* Check if a phase triggers the first initialized solution event. The first initialized solution immediately precedes
* the first {@link ai.timefold.solver.core.impl.localsearch.LocalSearchPhase}.
*
* @see ai.timefold.solver.core.api.solver.SolverJobBuilder#withFirstInitializedSolutionConsumer(Consumer)
*
* @return true if the phase returns the first initialized solution.
*/
boolean triggersFirstInitializedSolutionEvent();
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ static <Solution_> List<Phase<Solution_>> buildPhases(List<PhaseConfig> phaseCon
HeuristicConfigPolicy<Solution_> configPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> termination) {
List<Phase<Solution_>> phaseList = new ArrayList<>(phaseConfigList.size());
boolean isPhaseSelected = false;
for (int phaseIndex = 0; phaseIndex < phaseConfigList.size(); phaseIndex++) {
PhaseConfig phaseConfig = phaseConfigList.get(phaseIndex);
var phaseConfig = phaseConfigList.get(phaseIndex);
if (phaseIndex > 0) {
PhaseConfig previousPhaseConfig = phaseConfigList.get(phaseIndex - 1);
if (!canTerminate(previousPhaseConfig)) {
Expand All @@ -55,8 +56,20 @@ static <Solution_> List<Phase<Solution_>> buildPhases(List<PhaseConfig> phaseCon
+ "without a configured termination (" + previousPhaseConfig + ").");
}
}
// The initialization phase can only be applied to construction heuristics or custom phases
var isConstructionOrCustomPhase = ConstructionHeuristicPhaseConfig.class.isAssignableFrom(phaseConfig.getClass())
|| CustomPhaseConfig.class.isAssignableFrom(phaseConfig.getClass());
// The next phase must be a local search
var isNextPhaseLocalSearch = phaseIndex + 1 < phaseConfigList.size()
&& LocalSearchPhaseConfig.class.isAssignableFrom(phaseConfigList.get(phaseIndex + 1).getClass());
PhaseFactory<Solution_> phaseFactory = PhaseFactory.create(phaseConfig);
Phase<Solution_> phase = phaseFactory.buildPhase(phaseIndex, configPolicy, bestSolutionRecaller, termination);
var phase = phaseFactory.buildPhase(phaseIndex,
!isPhaseSelected && isConstructionOrCustomPhase && isNextPhaseLocalSearch, configPolicy,
bestSolutionRecaller, termination);
// Ensure only one initialization phase is set
if (!isPhaseSelected && isConstructionOrCustomPhase && isNextPhaseLocalSearch) {
isPhaseSelected = true;
}
phaseList.add(phase);
}
return phaseList;
Expand All @@ -72,6 +85,7 @@ static boolean canTerminate(PhaseConfig phaseConfig) {
return (terminationConfig != null && terminationConfig.isConfigured());
}

Phase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination);
Phase<Solution_> buildPhase(int phaseIndex, boolean triggerFirstInitializedSolutionEvent,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination);
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,10 @@ public static final class Builder<Solution_> extends AbstractPhase.Builder<Solut

private final List<CustomPhaseCommand<Solution_>> customPhaseCommandList;

public Builder(int phaseIndex, String logIndentation, Termination<Solution_> phaseTermination,
public Builder(int phaseIndex, boolean triggerFirstInitializedSolutionEvent, String logIndentation,
Termination<Solution_> phaseTermination,
List<CustomPhaseCommand<Solution_>> customPhaseCommandList) {
super(phaseIndex, logIndentation, phaseTermination);
super(phaseIndex, triggerFirstInitializedSolutionEvent, logIndentation, phaseTermination);
this.customPhaseCommandList = List.copyOf(customPhaseCommandList);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ public DefaultCustomPhaseFactory(CustomPhaseConfig phaseConfig) {
}

@Override
public CustomPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public CustomPhase<Solution_> buildPhase(int phaseIndex, boolean triggerFirstInitializedSolutionEvent,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
HeuristicConfigPolicy<Solution_> phaseConfigPolicy = solverConfigPolicy.createPhaseConfigPolicy();
if (ConfigUtils.isEmptyCollection(phaseConfig.getCustomPhaseCommandClassList())
&& ConfigUtils.isEmptyCollection(phaseConfig.getCustomPhaseCommandList())) {
Expand All @@ -43,7 +44,8 @@ public CustomPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<S
customPhaseCommandList_.addAll((Collection) phaseConfig.getCustomPhaseCommandList());
}
DefaultCustomPhase.Builder<Solution_> builder =
new DefaultCustomPhase.Builder<>(phaseIndex, solverConfigPolicy.getLogIndentation(),
new DefaultCustomPhase.Builder<>(phaseIndex, triggerFirstInitializedSolutionEvent,
solverConfigPolicy.getLogIndentation(),
buildPhaseTermination(phaseConfigPolicy, solverTermination), customPhaseCommandList_);
EnvironmentMode environmentMode = phaseConfigPolicy.getEnvironmentMode();
if (environmentMode.isNonIntrusiveFullAsserted()) {
Expand Down
Loading
Loading