Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redesign MemoTable, CostEstimator, PlanEnumerator for Optimal Plan Selection #2147

Closed
wants to merge 7 commits into from

Conversation

min-guk
Copy link
Contributor

@min-guk min-guk commented Nov 28, 2024

I have implemented FederatedPlanCostEstimator, FederatedPlanCostEnumerator, and FederatedMemoTable.

However, this implementation differs significantly from the coding direction we discussed in the meeting, so we need detailed discussion about the implementation direction in the next meeting.

Please review my thoughts and advise if my understanding is correct.

1. FederatedMemoTable (MemoTable)

public class MemoTable {
        private final Map<Pair<Long, FTypes.FType>, List<FedPlan>> hopMemoTable = new HashMap<>();

	public static class FedPlan {
		@SuppressWarnings("unused")
		private final Hop hopRef;                       // The associated Hop object
		private final double cost;                      // Cost of this federated plan
		@SuppressWarnings("unused")
		private final List<Pair<Long, FType>> planRefs;	// References to dependent plans
	}
}

The previous FedPlan class structure had several issues:

  • A single <HopID, FederatedOutput> pair stored multiple FedPlans as a list in the MemoTable, redundantly storing the hopRef.
  • A single <HopID, FederatedOutput> pair had to calculate its computeCost and accessCost 2^(planRefs+1) times redundantly.
  • FedPlan did not store its own FederatedOutput
public class FederatedMemoTable {
    private final Map<Pair<Long, FederatedOutput>, FedPlanVariants> hopMemoTable = new HashMap<>();

    public static class FedPlanVariants {
        protected final Hop hopRef;         // Reference to the associated Hop
        protected double currentCost;       // Current execution cost (compute + memory access)
        protected double netTransferCost;   // Network transfer cost
        protected List<FedPlan> _fedPlanVariants;  
    }
    public static class FedPlan {
        private double cumulativeCost;                  // Total cost including child plans
        private final FederatedOutput fedOutType;       // Output type (FOUT/LOUT)
        private final FedPlanVariants fedPlanVariantList;  // Reference to variant list
        private List<Pair<Long, FederatedOutput>> metaChildFedPlans;  // Child plan references
        private List<FedPlan> selectedFedPlans;           // Selected child plans
    }

The key points of the redesigned FederatedMemoTable are as follows:

  • A single <HopID, FederatedOutput> pair has one FedPlanVariants, which stores and shares the redundant hopRef, currentCost, and netTransferCost with FedPlans stored in fedPlanVariants.
  • A single <HopID, FederatedOutput> pair calculates its computeCost and accessCost only once.
  • FedPlan stores its own FederatedOutput.

2. CostEstimator

    // Do not create and allocate any new FedPlan.
    // just calculate the cost for given fed plans.
    // cost of dependent fedplans in planRefs is already calculated.
    public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable){
        double cost = computeFederatedPlanCost(currentPlan.getHopRef());

        for (Pair<Long, FederatedOutput> planRefMeta: currentPlan.getPlanRefs()){
            FedPlan planRef = memoTable.getFedPlan(planRefMeta.getLeft(), planRefMeta.getRight());
            cost += planRef.getCost();

            if (currentPlan.getFedOutType() != planRef.getFedOutType()){
                cost += computeHopNetworkAccessCost(planRef.getHopRef().getOutputMemEstimate());
            }
        }
        currentPlan.setCost(cost);
    }

The previous CostEstimator also had several issues:

  • It calculates the currentHop's cost every time.
  • The Optimal FedPlan should minimize the total cost of compute, memory access, and network access.
  • However, the previous CostEstimator selects the ref plan with minimum cost excluding network cost, and then adds network cost afterward, so it cannot guarantee the minimum cost FedPlan.
    public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) {
        double cumulativeCost = 0;
        Hop currentHop = currentPlan.getHopRef();

        // Step 1: Calculate current node costs if not already computed
        if (currentPlan.getCurrentCost() == 0) {
            // Compute cost for current node (computation + memory access)
            cumulativeCost = computeCurrentCost(currentHop);
            currentPlan.setCurrentCost(cumulativeCost);
            // Calculate potential network transfer cost if federation type changes
            currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate()));
        } else {
            cumulativeCost = currentPlan.getCurrentCost();
        }
        
        // Step 2: Process each child plan and add their costs
        for (Pair<Long, FederatedOutput> planRefMeta : currentPlan.getMetaChildFedPlans()) {
            // Find minimum cost child plan considering federation type compatibility
            // Note: This approach might lead to suboptimal or wrong solutions when a child has multiple parents
            // because we're selecting child plans independently for each parent
            FedPlan planRef = memoTable.getMinCostChildFedPlan(
                    planRefMeta.getLeft(), planRefMeta.getRight(), currentPlan.getFedOutType());

            // Add child plan cost (includes network transfer cost if federation types differ)
            cumulativeCost += planRef.getParentViewCost(currentPlan.getFedOutType());
            
            // Store selected child plan
            // Note: Selected plan has minimum parent view cost, not minimum cumulative cost,
            // which means it highly unlikely to be found through simple pruning after enumeration
            currentPlan.putChildFedPlan(planRef);
        }
        
        // Step 3: Set final cumulative cost including current node
        currentPlan.setCumulativeCost(cumulativeCost);
    }

The key points of the redesigned CostEstimator are as follows:

  • It calculates the compute cost and access cost of currentHop only once per HopID.
  • When selecting the minimum cost ref plan, it selects the ref plan including network cost, ensuring minimum total cost.
  • It stores selected child plans in a list as pointers.
    • This is because when pruning all at once in the memotable later, we cannot calculate network cost without knowing the fOutType of each fedplan's parent fedplan, so we cannot identify the optimal cost plan. Therefore, pruning in the current MemoTable has been removed.

However, the current CostEstimator may cause two problems because it selects child plans based only on the cost of a single current plan and child plan:

  1. A child plan can have multiple parent plans, and different parent plans can select different child plans. Therefore, a child plan could form a non-existent fed plan with different fOutTypes.
  2. Since a child plan can have multiple parent plans, it should select the fOutType that minimizes the sum of costs of all parent plans referencing it. Otherwise, it may select a suboptimal plan.
  • We need to devise a new algorithm to solve these two problems.

3. FederatedPlanCostEnumerator

public class FederatedPlanCostEnumerator {
    public static FedPlan enumerateFederatedPlanCost(Hop rootHop) {              
        FederatedMemoTable memoTable = new FederatedMemoTable();
        enumerateFederatedPlanCost(rootHop, memoTable);
        return getMinCostRootFedPlan(rootHop.getHopID(), memoTable);
    }

    /**
     * Recursively enumerates all possible federated execution plans for a Hop DAG.
     * For each node:
     * 1. First processes all input nodes recursively if not already processed
     * 2. Generates all possible combinations of federation types (FOUT/LOUT) for inputs
     * 3. Creates and evaluates both FOUT and LOUT variants for current node with each input combination
     * 
     * The enumeration uses a bottom-up approach where:
     * - Each input combination is represented by a binary number (i)
     * - Bit j in i determines whether input j is FOUT (1) or LOUT (0)
     * - Total number of combinations is 2^numInputs
     */
    private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) {              
        int numInputs = hop.getInput().size();

        // Process all input nodes first if not already in memo table
        for (Hop inputHop : hop.getInput()) {
            if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT) 
                && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) {
                    enumerateFederatedPlanCost(inputHop, memoTable);
            }
        }

        // Generate all possible input combinations using binary representation
        // i represents a specific combination of FOUT/LOUT for inputs
        for (int i = 0; i < (1 << numInputs); i++) {
            List<Pair<Long, FederatedOutput>> planChilds = new ArrayList<>(); 

            // For each input, determine if it should be FOUT or LOUT based on bit j in i
            for (int j = 0; j < numInputs; j++) {
                Hop inputHop = hop.getInput().get(j);
                // If bit j is set (1), use FOUT; otherwise use LOUT
                FederatedOutput childType = ((i & (1 << j)) != 0) ?
                    FederatedOutput.FOUT : FederatedOutput.LOUT;
                planChilds.add(Pair.of(inputHop.getHopID(), childType));
            }
            
            // Create and evaluate FOUT variant for current input combination
            FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds);
            FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable);

            // Create and evaluate LOUT variant for current input combination
            FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds);
            FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable);
        }
    }
}
  • This implementation is based on the newly implemented FederatedPlanCostEstimator and FederatedMemoTable, following the direction we previously discussed.
  • I'm not sure how to create complex Hop DAGs similar to real scenarios in the test code. Could you please provide some reference test code that I can refer to?

@mboehm7
Copy link
Contributor

mboehm7 commented Dec 4, 2024

Thanks for the contribution @min-guk - as we just discussed please clean up the enumeration logic for arbitrary many inputs, and in a subsequent PR add a debugging print out of the memo table and its federated plans.

@min-guk
Copy link
Contributor Author

min-guk commented Dec 20, 2024

Updates

  1. Modified the selection process for child FedPlan to choose the plan with the minimum cost among those with the same fedOutType.
  2. Added functionality to output the DAG of the optimal FedPlan.
  3. Developed an integrated test code that combines the Memo Table, Cost Estimator, and Cost Enumerator instead of focusing on specific components.

Unchanged Features

  1. The FedPlan and FedVariant classes are retained to avoid redundant cost calculations.
  2. The Cost Enumerator already supports multiple child FedPlans (two or more), and no modifications were made.

Discussion Points

  1. In the current test code, the _outputMemEstimate is initialized to -1, leading to incorrect memory access and network cost calculations, which produce an invalid FedPlan. This needs to be fixed.
  2. If necessary, we may need to verify that all operations in the federated plans correctly set _outputMemEstimate or update the Cost Estimator to account for such exceptional cases.

@min-guk
Copy link
Contributor Author

min-guk commented Dec 20, 2024

└──Hop 1 [LOUT] (Total: 2.000, Self: 1.000, Net: -0.000)
       └─Hop 0 [FOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
└──Hop 12 [LOUT] (Total: 2.000, Self: 1.000, Net: -0.000)
       └─Hop 11 [FOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
└──Hop 23 [LOUT] (Total: 2.000, Self: 1.000, Net: -0.000)
       └─Hop 22 [FOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
└──Hop 42 [LOUT] (Total: 28.000, Self: 10.000, Net: -0.000)
       └─Hop 41 [FOUT] (Total: 18.000, Self: 10.000, Net: -0.000)
              ├─Hop 33 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 34 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 35 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 36 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 37 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 38 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 39 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              └─Hop 40 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
└──Hop 65 [LOUT] (Total: 28.000, Self: 10.000, Net: -0.000)
       └─Hop 64 [FOUT] (Total: 18.000, Self: 10.000, Net: -0.000)
              ├─Hop 56 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 57 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 58 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 59 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 60 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 61 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              ├─Hop 62 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)
              └─Hop 63 [LOUT] (Total: 1.000, Self: 1.000, Net: -0.000)

Process finished with exit code 0

Here is the output of the DAG from the TestCode for your reference.

@min-guk
Copy link
Contributor Author

min-guk commented Dec 20, 2024

Updates

As discussed in the meeting, I have updated the following components:

  • CostEnumeratorTest
  • printFedPlanTreeRecursive

Updated Output of printFedPlanTreeRecursive

The following is the updated result of printFedPlanTreeRecursive:

(14) u(print) [LOUT] (13) {Total: 911.0, Self: 1.0, Net: 0.0} [-1, -1, -1, -1] [0, 0, 0 -> 0MB]
(13) ua(+RC) [LOUT] (12) {Total: 910.0, Self: 400.0, Net: 0.0} [0, 0, -1, -1] [0, 0, 0 -> 0MB], CP
(12) u(sqrt) [LOUT] (11) {Total: 510.0, Self: 200.0, Net: 0.0} [10, 10, 1000, -1] [0, 0, 0 -> 0MB], CP
(11) b(+) [LOUT] (1,10) {Total: 310.0, Self: 100.0, Net: 0.0} [10, 10, 1000, -1] [0, 0, 0 -> 0MB], CP
(1) LiteralOp 7 [LOUT] {Total: 1.0, Self: 1.0, Net: 0.0} [0, 0, -1, -1] [0, 0, 0 -> 0MB]
(10) b(^) [LOUT] (8,9) {Total: 209.0, Self: 100.0, Net: 0.0} [10, 10, 1000, 100] [0, 0, 0 -> 0MB], CP
(8) dg(rand) [LOUT] (0,1,2,3,1,0,6,6) {Total: 108.0, Self: 100.0, Net: 0.0} [10, 10, 1000, 100] [0, 0, 0 -> 0MB], CP
(0) LiteralOp 1.0 [LOUT] {Total: 1.0, Self: 1.0, Net: 0.0} [0, 0, -1, -1] [0, 0, 0 -> 0MB]
(2) LiteralOp uniform [LOUT] {Total: 1.0, Self: 1.0, Net: 0.0} [0, 0, -1, -1] [0, 0, 0 -> 0MB]
(3) LiteralOp -1 [LOUT] {Total: 1.0, Self: 1.0, Net: 0.0} [0, 0, -1, -1] [0, 0, 0 -> 0MB]
(6) LiteralOp 10 [LOUT] {Total: 1.0, Self: 1.0, Net: 0.0} [0, 0, -1, -1] [0, 0, 0 -> 0MB]
(9) LiteralOp 2 [LOUT] {Total: 1.0, Self: 1.0, Net: 0.0} [0, 0, -1, -1] [0, 0, 0 -> 0MB]

Next Tasks

The upcoming tasks are as follows:

  1. Implement exception cases for costestimator.
  2. Add test DML scripts that include for loops.
  3. Brainstorm an implementation method for a globally optimal federated plan considering control flow.

Please review the updates and let me know if further adjustments are needed.

Copy link

codecov bot commented Dec 21, 2024

Codecov Report

Attention: Patch coverage is 89.41176% with 18 lines in your changes missing coverage. Please review.

Project coverage is 72.05%. Comparing base (d3fcfb1) to head (0fcfc6d).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
...ache/sysds/hops/fedplanner/FederatedMemoTable.java 89.38% 6 Missing and 6 partials ⚠️
...s/hops/fedplanner/FederatedPlanCostEnumerator.java 86.48% 2 Missing and 3 partials ⚠️
...ds/hops/fedplanner/FederatedPlanCostEstimator.java 95.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##               main    #2147      +/-   ##
============================================
+ Coverage     72.03%   72.05%   +0.01%     
- Complexity    43937    43961      +24     
============================================
  Files          1441     1443       +2     
  Lines        166106   166239     +133     
  Branches      32428    32453      +25     
============================================
+ Hits         119655   119776     +121     
- Misses        37199    37212      +13     
+ Partials       9252     9251       -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@mboehm7
Copy link
Contributor

mboehm7 commented Dec 21, 2024

LGTM - Thanks for the patch @min-guk. During the merge I fixed the formatting (tabs over spaces in Java), added missing licenses, and fixed remaining javadoc issues.

@mboehm7 mboehm7 closed this in 29b4d92 Dec 21, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

2 participants