Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redesign MemoTable, CostEstimator, PlanEnumerator for Optimal Plan Selection #2147

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 282 additions & 0 deletions src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.hops.fedplanner;

import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.ArrayList;
import java.util.Map;
import java.util.HashSet;
import java.util.Set;

/**
* A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes.
* This table stores and manages different execution plan variants for each Hop and fedOutType combination,
* facilitating the optimization of federated execution plans.
*/
public class FederatedMemoTable {
// Maps Hop ID and fedOutType pairs to their plan variants
private final Map<Pair<Long, FederatedOutput>, FedPlanVariants> hopMemoTable = new HashMap<>();

/**
* Adds a new federated plan to the memo table.
* Creates a new variant list if none exists for the given Hop and fedOutType.
*
* @param hop The Hop node
* @param fedOutType The federated output type
* @param planChilds List of child plan references
* @return The newly created FedPlan
*/
public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List<Pair<Long, FederatedOutput>> planChilds) {
long hopID = hop.getHopID();
FedPlanVariants fedPlanVariantList;

if (contains(hopID, fedOutType)) {
fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
} else {
fedPlanVariantList = new FedPlanVariants(hop, fedOutType);
hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariantList);
}

FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList);
fedPlanVariantList.addFedPlan(newPlan);

return newPlan;
}

/**
* Retrieves the minimum cost child plan considering the parent's output type.
* The cost is calculated using getParentViewCost to account for potential type mismatches.
*/
public FedPlan getMinCostChildFedPlan(long childHopID, FederatedOutput childFedOutType) {
FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(childHopID, childFedOutType));
return fedPlanVariantList._fedPlanVariants.stream()
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
.orElse(null);
}

public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType) {
return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
}

/**
* Checks if the memo table contains an entry for a given Hop and fedOutType.
*
* @param hopID The Hop ID.
* @param fedOutType The associated fedOutType.
* @return True if the entry exists, false otherwise.
*/
public boolean contains(long hopID, FederatedOutput fedOutType) {
return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType));
}

/**
* Prunes all entries in the memo table, retaining only the minimum-cost
* FedPlan for each entry.
*/
public void pruneMemoTable() {
for (Map.Entry<Pair<Long, FederatedOutput>, FedPlanVariants> entry : hopMemoTable.entrySet()) {
List<FedPlan> fedPlanList = entry.getValue().getFedPlanVariants();
if (fedPlanList.size() > 1) {
// Find the FedPlan with the minimum cost
FedPlan minCostPlan = fedPlanList.stream()
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
.orElse(null);

// Retain only the minimum cost plan
fedPlanList.clear();
fedPlanList.add(minCostPlan);
}
}
}

/**
* Recursively prints a tree representation of the DAG starting from the given root FedPlan.
* Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node.
*
* @param rootFedPlan The starting point FedPlan to print
*/
public void printFedPlanTree(FedPlan rootFedPlan) {
Set<FedPlan> visited = new HashSet<>();
printFedPlanTreeRecursive(rootFedPlan, visited, 0, true);
}

/**
* Helper method to recursively print the FedPlan tree.
*
* @param plan The current FedPlan to print
* @param visited Set to keep track of visited FedPlans (prevents cycles)
* @param depth The current depth level for indentation
* @param isLast Whether this node is the last child of its parent
*/
private void printFedPlanTreeRecursive(FedPlan plan, Set<FedPlan> visited, int depth, boolean isLast) {
if (plan == null || visited.contains(plan)) {
return;
}

visited.add(plan);

Hop hop = plan.getHopRef();
StringBuilder sb = new StringBuilder();

// Add FedPlan information
sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
.append(plan.getHopRef().getOpString())
.append(" [")
.append(plan.getFedOutType())
.append("]");

StringBuilder childs = new StringBuilder();
childs.append(" (");
boolean childAdded = false;
for( Hop input : hop.getInput()){
childs.append(childAdded?",":"");
childs.append(input.getHopID());
childAdded = true;
}
childs.append(")");
if( childAdded )
sb.append(childs.toString());


sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
plan.getTotalCost(),
plan.getSelfCost(),
plan.getNetTransferCost()));

// Add matrix characteristics
sb.append(" [")
.append(hop.getDim1()).append(", ")
.append(hop.getDim2()).append(", ")
.append(hop.getBlocksize()).append(", ")
.append(hop.getNnz());

if (hop.getUpdateType().isInPlace()) {
sb.append(", ").append(hop.getUpdateType().toString().toLowerCase());
}
sb.append("]");

// Add memory estimates
sb.append(" [")
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");

// Add reblock and checkpoint requirements
if (hop.requiresReblock() && hop.requiresCheckpoint()) {
sb.append(" [rblk, chkpt]");
} else if (hop.requiresReblock()) {
sb.append(" [rblk]");
} else if (hop.requiresCheckpoint()) {
sb.append(" [chkpt]");
}

// Add execution type
if (hop.getExecType() != null) {
sb.append(", ").append(hop.getExecType());
}

System.out.println(sb);

// Process child nodes
List<Pair<Long, FederatedOutput>> childRefs = plan.getChildFedPlans();
for (int i = 0; i < childRefs.size(); i++) {
Pair<Long, FederatedOutput> childRef = childRefs.get(i);
FedPlanVariants childVariants = getFedPlanVariants(childRef.getLeft(), childRef.getRight());
if (childVariants == null || childVariants.getFedPlanVariants().isEmpty())
continue;

boolean isLastChild = (i == childRefs.size() - 1);
for (FedPlan childPlan : childVariants.getFedPlanVariants()) {
printFedPlanTreeRecursive(childPlan, visited, depth + 1, isLastChild);
}
}
}

/**
* Represents a collection of federated execution plan variants for a specific Hop.
* Contains cost information and references to the associated plans.
*/
public static class FedPlanVariants {
protected final Hop hopRef; // Reference to the associated Hop
protected double selfCost; // Current execution cost (compute + memory access)
protected double netTransferCost; // Network transfer cost
private final FederatedOutput fedOutType; // Output type (FOUT/LOUT)
protected List<FedPlan> _fedPlanVariants; // List of plan variants

public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) {
this.hopRef = hopRef;
this.fedOutType = fedOutType;
this.selfCost = 0;
this.netTransferCost = 0;
this._fedPlanVariants = new ArrayList<>();
}

public int size() {return _fedPlanVariants.size();}
public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);}
public List<FedPlan> getFedPlanVariants() {return _fedPlanVariants;}
}

/**
* Represents a single federated execution plan with its associated costs and dependencies.
* Contains:
* 1. selfCost: Cost of current hop (compute + input/output memory access)
* 2. totalCost: Cumulative cost including this plan and all child plans
* 3. netTransferCost: Network transfer cost for this plan to parent plan.
*/
public static class FedPlan {
private double totalCost; // Total cost including child plans
private final FedPlanVariants fedPlanVariants; // Reference to variant list
private final List<Pair<Long, FederatedOutput>> childFedPlans; // Child plan references

public FedPlan(List<Pair<Long, FederatedOutput>> childFedPlans, FedPlanVariants fedPlanVariants) {
this.totalCost = 0;
this.childFedPlans = childFedPlans;
this.fedPlanVariants = fedPlanVariants;
}

public void setTotalCost(double totalCost) {this.totalCost = totalCost;}
public void setSelfCost(double selfCost) {fedPlanVariants.selfCost = selfCost;}
public void setNetTransferCost(double netTransferCost) {fedPlanVariants.netTransferCost = netTransferCost;}

public Hop getHopRef() {return fedPlanVariants.hopRef;}
public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;}
public double getTotalCost() {return totalCost;}
public double getSelfCost() {return fedPlanVariants.selfCost;}
private double getNetTransferCost() {return fedPlanVariants.netTransferCost;}
public List<Pair<Long, FederatedOutput>> getChildFedPlans() {return childFedPlans;}

/**
* Calculates the conditional network transfer cost based on output type compatibility.
* Returns 0 if output types match, otherwise returns the network transfer cost.
*/
public double getCondNetTransferCost(FederatedOutput parentFedOutType) {
if (parentFedOutType == getFedOutType()) return 0;
return fedPlanVariants.netTransferCost;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package org.apache.sysds.hops.fedplanner;
import java.util.ArrayList;
import java.util.List;
import java.util.Comparator;
import java.util.Objects;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;

/**
* Enumerates and evaluates all possible federated execution plans for a given Hop DAG.
* Works with FederatedMemoTable to store plan variants and FederatedPlanCostEstimator
* to compute their costs.
*/
public class FederatedPlanCostEnumerator {
/**
* Entry point for federated plan enumeration. Creates a memo table and returns
* the minimum cost plan for the entire DAG.
*/
public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree) {
// Create new memo table to store all plan variants
FederatedMemoTable memoTable = new FederatedMemoTable();

// Recursively enumerate all possible plans
enumerateFederatedPlanCost(rootHop, memoTable);

// Return the minimum cost plan for the root node
FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable);
memoTable.pruneMemoTable();
if (printTree) memoTable.printFedPlanTree(optimalPlan);

return optimalPlan;
}

/**
* Recursively enumerates all possible federated execution plans for a Hop DAG.
* For each node:
* 1. First processes all input nodes recursively if not already processed
* 2. Generates all possible combinations of federation types (FOUT/LOUT) for inputs
* 3. Creates and evaluates both FOUT and LOUT variants for current node with each input combination
*
* The enumeration uses a bottom-up approach where:
* - Each input combination is represented by a binary number (i)
* - Bit j in i determines whether input j is FOUT (1) or LOUT (0)
* - Total number of combinations is 2^numInputs
*/
private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) {
int numInputs = hop.getInput().size();

// Process all input nodes first if not already in memo table
for (Hop inputHop : hop.getInput()) {
if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT)
&& !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) {
enumerateFederatedPlanCost(inputHop, memoTable);
}
}

// Generate all possible input combinations using binary representation
// i represents a specific combination of FOUT/LOUT for inputs
for (int i = 0; i < (1 << numInputs); i++) {
List<Pair<Long, FederatedOutput>> planChilds = new ArrayList<>();

// For each input, determine if it should be FOUT or LOUT based on bit j in i
for (int j = 0; j < numInputs; j++) {
Hop inputHop = hop.getInput().get(j);
// If bit j is set (1), use FOUT; otherwise use LOUT
FederatedOutput childType = ((i & (1 << j)) != 0) ?
FederatedOutput.FOUT : FederatedOutput.LOUT;
planChilds.add(Pair.of(inputHop.getHopID(), childType));
}

// Create and evaluate FOUT variant for current input combination
FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds);
FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable);

// Create and evaluate LOUT variant for current input combination
FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds);
FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable);
}
}

/**
* Returns the minimum cost plan for the root Hop, comparing both FOUT and LOUT variants.
* Used to select the final execution plan after enumeration.
*/
private static FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memoTable) {
FedPlanVariants fOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT);
FedPlanVariants lOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT);

FedPlan minFOutFedPlan = fOutFedPlanVariants._fedPlanVariants.stream()
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
.orElse(null);
FedPlan minlOutFedPlan = lOutFedPlanVariants._fedPlanVariants.stream()
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
.orElse(null);

if (Objects.requireNonNull(minFOutFedPlan).getTotalCost()
< Objects.requireNonNull(minlOutFedPlan).getTotalCost()) {
return minFOutFedPlan;
}
return minlOutFedPlan;
}

}
Loading
Loading