Skip to content

Commit

Permalink
feat: add initialization property and lifecycle listener
Browse files Browse the repository at this point in the history
  • Loading branch information
zepfred committed Jun 24, 2024
1 parent d776e12 commit 30d23a2
Show file tree
Hide file tree
Showing 21 changed files with 155 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ default SolverJobBuilder<Solution_, ProblemId_> withProblem(Solution_ problem) {
SolverJobBuilder<Solution_, ProblemId_>
withFinalBestSolutionConsumer(Consumer<? super Solution_> finalBestSolutionConsumer);

/**
* Sets the initialized solution consumer, which is called before starting the first
* {@link ai.timefold.solver.core.impl.localsearch.LocalSearchPhase} phase.
*
* @param initializedSolutionConsumer never null, called only once before starting the first Local Search phase
* @return this, never null
*/
SolverJobBuilder<Solution_, ProblemId_>
withInitializedSolutionConsumer(Consumer<? super Solution_> initializedSolutionConsumer);

/**
* Sets the custom exception handler.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ public static class Builder<Solution_> extends AbstractPhase.Builder<Solution_>
private final EntityPlacer<Solution_> entityPlacer;
private final ConstructionHeuristicDecider<Solution_> decider;

public Builder(int phaseIndex, String logIndentation, Termination<Solution_> phaseTermination,
public Builder(int phaseIndex, boolean initializationPhase, String logIndentation, Termination<Solution_> phaseTermination,
EntityPlacer<Solution_> entityPlacer, ConstructionHeuristicDecider<Solution_> decider) {
super(phaseIndex, logIndentation, phaseTermination);
super(phaseIndex, initializationPhase, logIndentation, phaseTermination);
this.entityPlacer = entityPlacer;
this.decider = decider;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ public DefaultConstructionHeuristicPhaseFactory(ConstructionHeuristicPhaseConfig
}

@Override
public ConstructionHeuristicPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public ConstructionHeuristicPhase<Solution_> buildPhase(int phaseIndex, boolean initializationPhase,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
ConstructionHeuristicType constructionHeuristicType_ = Objects.requireNonNullElse(
phaseConfig.getConstructionHeuristicType(),
ConstructionHeuristicType.ALLOCATE_ENTITY_FROM_QUEUE);
Expand All @@ -71,6 +72,7 @@ public ConstructionHeuristicPhase<Solution_> buildPhase(int phaseIndex, Heuristi

DefaultConstructionHeuristicPhase.Builder<Solution_> builder = new DefaultConstructionHeuristicPhase.Builder<>(
phaseIndex,
initializationPhase,
solverConfigPolicy.getLogIndentation(),
phaseTermination,
entityPlacer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,11 @@ public static class Builder<Solution_> extends AbstractPhase.Builder<Solution_>
private boolean assertWorkingSolutionScoreFromScratch = false;
private boolean assertExpectedWorkingSolutionScore = false;

public Builder(int phaseIndex, String logIndentation, Termination<Solution_> phaseTermination,
Comparator<ExhaustiveSearchNode> nodeComparator, EntitySelector<Solution_> entitySelector,
public Builder(int phaseIndex, boolean initializationPhase, String logIndentation,
Termination<Solution_> phaseTermination, Comparator<ExhaustiveSearchNode> nodeComparator,
EntitySelector<Solution_> entitySelector,
ExhaustiveSearchDecider<Solution_> decider) {
super(phaseIndex, logIndentation, phaseTermination);
super(phaseIndex, initializationPhase, logIndentation, phaseTermination);
this.nodeComparator = nodeComparator;
this.entitySelector = entitySelector;
this.decider = decider;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ public DefaultExhaustiveSearchPhaseFactory(ExhaustiveSearchPhaseConfig phaseConf
}

@Override
public ExhaustiveSearchPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public ExhaustiveSearchPhase<Solution_> buildPhase(int phaseIndex, boolean initializationPhase,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
ExhaustiveSearchType exhaustiveSearchType_ = Objects.requireNonNullElse(
phaseConfig.getExhaustiveSearchType(),
ExhaustiveSearchType.BRANCH_AND_BOUND);
Expand Down Expand Up @@ -83,7 +84,7 @@ EntitySelectorFactory.<Solution_> create(entitySelectorConfig_)
.buildEntitySelector(phaseConfigPolicy, SelectionCacheType.PHASE, SelectionOrder.ORIGINAL);

DefaultExhaustiveSearchPhase.Builder<Solution_> builder = new DefaultExhaustiveSearchPhase.Builder<>(phaseIndex,
solverConfigPolicy.getLogIndentation(), phaseTermination,
initializationPhase, solverConfigPolicy.getLogIndentation(), phaseTermination,
nodeExplorationType_.buildNodeComparator(scoreBounderEnabled), entitySelector, buildDecider(phaseConfigPolicy,
entitySelector, bestSolutionRecaller, phaseTermination, scoreBounderEnabled));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ public static class Builder<Solution_> extends AbstractPhase.Builder<Solution_>

public Builder(int phaseIndex, String logIndentation, Termination<Solution_> phaseTermination,
LocalSearchDecider<Solution_> decider) {
super(phaseIndex, logIndentation, phaseTermination);
// By definition, the Local Search does not return an initialized solution
super(phaseIndex, false, logIndentation, phaseTermination);
this.decider = decider;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ public DefaultLocalSearchPhaseFactory(LocalSearchPhaseConfig phaseConfig) {
}

@Override
public LocalSearchPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public LocalSearchPhase<Solution_> buildPhase(int phaseIndex, boolean initializationPhase,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
HeuristicConfigPolicy<Solution_> phaseConfigPolicy = solverConfigPolicy.createPhaseConfigPolicy();
Termination<Solution_> phaseTermination = buildPhaseTermination(phaseConfigPolicy, solverTermination);
DefaultLocalSearchPhase.Builder<Solution_> builder =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ public DefaultPartitionedSearchPhaseFactory(PartitionedSearchPhaseConfig phaseCo
}

@Override
public PartitionedSearchPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public PartitionedSearchPhase<Solution_> buildPhase(int phaseIndex, boolean initializationPhase,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
return TimefoldSolverEnterpriseService.loadOrFail(TimefoldSolverEnterpriseService.Feature.PARTITIONED_SEARCH)
.buildPartitionedSearch(phaseIndex, phaseConfig, solverConfigPolicy, solverTermination,
this::buildPhaseTermination);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public abstract class AbstractPhase<Solution_> implements Phase<Solution_> {
protected final boolean assertStepScoreFromScratch;
protected final boolean assertExpectedStepScore;
protected final boolean assertShadowVariablesAreNotStaleAfterStep;
protected final boolean initializationPhase;

/** Used for {@link #addPhaseLifecycleListener(PhaseLifecycleListener)}. */
protected PhaseLifecycleSupport<Solution_> phaseLifecycleSupport = new PhaseLifecycleSupport<>();
Expand All @@ -45,6 +46,7 @@ protected AbstractPhase(Builder<Solution_> builder) {
assertStepScoreFromScratch = builder.assertStepScoreFromScratch;
assertExpectedStepScore = builder.assertExpectedStepScore;
assertShadowVariablesAreNotStaleAfterStep = builder.assertShadowVariablesAreNotStaleAfterStep;
initializationPhase = builder.initializationPhase;
}

public int getPhaseIndex() {
Expand Down Expand Up @@ -77,6 +79,11 @@ public boolean isAssertShadowVariablesAreNotStaleAfterStep() {

public abstract String getPhaseTypeString();

@Override
public boolean isInitializationPhase() {
return initializationPhase;
}

// ************************************************************************
// Lifecycle methods
// ************************************************************************
Expand Down Expand Up @@ -209,15 +216,18 @@ but planning list variable (%s) has (%d) unexpected unassigned values.
protected abstract static class Builder<Solution_> {

private final int phaseIndex;
private final boolean initializationPhase;
private final String logIndentation;
private final Termination<Solution_> phaseTermination;

private boolean assertStepScoreFromScratch = false;
private boolean assertExpectedStepScore = false;
private boolean assertShadowVariablesAreNotStaleAfterStep = false;

protected Builder(int phaseIndex, String logIndentation, Termination<Solution_> phaseTermination) {
protected Builder(int phaseIndex, boolean initializationPhase, String logIndentation,
Termination<Solution_> phaseTermination) {
this.phaseIndex = phaseIndex;
this.initializationPhase = initializationPhase;
this.logIndentation = logIndentation;
this.phaseTermination = phaseTermination;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void solve(SolverScope<Solution_> solverScope) {
public static class Builder<Solution_> extends AbstractPhase.Builder<Solution_> {

public Builder(int phaseIndex, String logIndentation, Termination<Solution_> phaseTermination) {
super(phaseIndex, logIndentation, phaseTermination);
super(phaseIndex, false, logIndentation, phaseTermination);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ public NoChangePhaseFactory(NoChangePhaseConfig phaseConfig) {
}

@Override
public NoChangePhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public NoChangePhase<Solution_> buildPhase(int phaseIndex, boolean initializationPhase,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
HeuristicConfigPolicy<Solution_> phaseConfigPolicy = solverConfigPolicy.createPhaseConfigPolicy();
return new NoChangePhase.Builder<>(phaseIndex, solverConfigPolicy.getLogIndentation(),
buildPhaseTermination(phaseConfigPolicy, solverTermination)).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,10 @@ public interface Phase<Solution_> extends PhaseLifecycleListener<Solution_> {

void solve(SolverScope<Solution_> solverScope);

/**
* Checks if a phase returns an initialized solution.
*
* @return true if the phase returns an initialized solution.
*/
boolean isInitializationPhase();
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ static <Solution_> List<Phase<Solution_>> buildPhases(List<PhaseConfig> phaseCon
Termination<Solution_> termination) {
List<Phase<Solution_>> phaseList = new ArrayList<>(phaseConfigList.size());
for (int phaseIndex = 0; phaseIndex < phaseConfigList.size(); phaseIndex++) {
PhaseConfig phaseConfig = phaseConfigList.get(phaseIndex);
var phaseConfig = phaseConfigList.get(phaseIndex);
if (phaseIndex > 0) {
PhaseConfig previousPhaseConfig = phaseConfigList.get(phaseIndex - 1);
if (!canTerminate(previousPhaseConfig)) {
Expand All @@ -55,8 +55,13 @@ static <Solution_> List<Phase<Solution_>> buildPhases(List<PhaseConfig> phaseCon
+ "without a configured termination (" + previousPhaseConfig + ").");
}
}
var isConstructionOrCustomPhase = ConstructionHeuristicPhaseConfig.class.isAssignableFrom(phaseConfig.getClass())
|| CustomPhaseConfig.class.isAssignableFrom(phaseConfig.getClass());
var isNextPhaseLocalSearch = phaseIndex + 1 < phaseConfigList.size()
&& LocalSearchPhaseConfig.class.isAssignableFrom(phaseConfigList.get(phaseIndex + 1).getClass());
PhaseFactory<Solution_> phaseFactory = PhaseFactory.create(phaseConfig);
Phase<Solution_> phase = phaseFactory.buildPhase(phaseIndex, configPolicy, bestSolutionRecaller, termination);
var phase = phaseFactory.buildPhase(phaseIndex, isConstructionOrCustomPhase && isNextPhaseLocalSearch, configPolicy,
bestSolutionRecaller, termination);
phaseList.add(phase);
}
return phaseList;
Expand All @@ -72,6 +77,7 @@ static boolean canTerminate(PhaseConfig phaseConfig) {
return (terminationConfig != null && terminationConfig.isConfigured());
}

Phase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination);
Phase<Solution_> buildPhase(int phaseIndex, boolean initializationPhase,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination);
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ public static final class Builder<Solution_> extends AbstractPhase.Builder<Solut

private final List<CustomPhaseCommand<Solution_>> customPhaseCommandList;

public Builder(int phaseIndex, String logIndentation, Termination<Solution_> phaseTermination,
public Builder(int phaseIndex, boolean initializationPhase, String logIndentation, Termination<Solution_> phaseTermination,
List<CustomPhaseCommand<Solution_>> customPhaseCommandList) {
super(phaseIndex, logIndentation, phaseTermination);
super(phaseIndex, initializationPhase, logIndentation, phaseTermination);
this.customPhaseCommandList = List.copyOf(customPhaseCommandList);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ public DefaultCustomPhaseFactory(CustomPhaseConfig phaseConfig) {
}

@Override
public CustomPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<Solution_> solverConfigPolicy,
BestSolutionRecaller<Solution_> bestSolutionRecaller, Termination<Solution_> solverTermination) {
public CustomPhase<Solution_> buildPhase(int phaseIndex, boolean initializationPhase,
HeuristicConfigPolicy<Solution_> solverConfigPolicy, BestSolutionRecaller<Solution_> bestSolutionRecaller,
Termination<Solution_> solverTermination) {
HeuristicConfigPolicy<Solution_> phaseConfigPolicy = solverConfigPolicy.createPhaseConfigPolicy();
if (ConfigUtils.isEmptyCollection(phaseConfig.getCustomPhaseCommandClassList())
&& ConfigUtils.isEmptyCollection(phaseConfig.getCustomPhaseCommandList())) {
Expand All @@ -43,7 +44,7 @@ public CustomPhase<Solution_> buildPhase(int phaseIndex, HeuristicConfigPolicy<S
customPhaseCommandList_.addAll((Collection) phaseConfig.getCustomPhaseCommandList());
}
DefaultCustomPhase.Builder<Solution_> builder =
new DefaultCustomPhase.Builder<>(phaseIndex, solverConfigPolicy.getLogIndentation(),
new DefaultCustomPhase.Builder<>(phaseIndex, initializationPhase, solverConfigPolicy.getLogIndentation(),
buildPhaseTermination(phaseConfigPolicy, solverTermination), customPhaseCommandList_);
EnvironmentMode environmentMode = phaseConfigPolicy.getEnvironmentMode();
if (environmentMode.isNonIntrusiveFullAsserted()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public abstract class AbstractPhaseScope<Solution_> {
protected final SolverScope<Solution_> solverScope;
protected final int phaseIndex;
protected final boolean phaseSendingBestSolutionEvents;

protected final boolean initializationPhase;
protected Long startingSystemTimeMillis;
protected Long startingScoreCalculationCount;
protected Score startingScore;
Expand All @@ -39,6 +39,14 @@ protected AbstractPhaseScope(SolverScope<Solution_> solverScope, int phaseIndex)
this(solverScope, phaseIndex, true);
}

/**
* As defined by #AbstractPhaseScope(SolverScope, int, boolean, boolean)
* with the initializationPhase parameter set to false.
*/
protected AbstractPhaseScope(SolverScope<Solution_> solverScope, int phaseIndex, boolean phaseSendingBestSolutionEvents) {
this(solverScope, phaseIndex, phaseSendingBestSolutionEvents, false);
}

/**
*
* @param solverScope never null
Expand All @@ -47,11 +55,13 @@ protected AbstractPhaseScope(SolverScope<Solution_> solverScope, int phaseIndex)
* or none at all;
* this is typical for construction heuristics,
* whose result only matters when it reached its natural end.
* @param initializationPhase set to false if the phase does not return the initialized solution
*/
protected AbstractPhaseScope(SolverScope<Solution_> solverScope, int phaseIndex, boolean phaseSendingBestSolutionEvents) {
protected AbstractPhaseScope(SolverScope<Solution_> solverScope, int phaseIndex, boolean phaseSendingBestSolutionEvents, boolean initializationPhase) {
this.solverScope = solverScope;
this.phaseIndex = phaseIndex;
this.phaseSendingBestSolutionEvents = phaseSendingBestSolutionEvents;
this.initializationPhase = initializationPhase;
}

public SolverScope<Solution_> getSolverScope() {
Expand All @@ -66,6 +76,10 @@ public boolean isPhaseSendingBestSolutionEvents() {
return phaseSendingBestSolutionEvents;
}

public boolean isInitializationPhase() {
return initializationPhase;
}

public Long getStartingSystemTimeMillis() {
return startingSystemTimeMillis;
}
Expand Down
Loading

0 comments on commit 30d23a2

Please sign in to comment.