|
39 | 39 | from torchrec.distributed.planner.types import (
|
40 | 40 | Enumerator,
|
41 | 41 | hash_planner_context_inputs,
|
| 42 | + hash_planner_context_inputs_str, |
42 | 43 | ParameterConstraints,
|
43 | 44 | Partitioner,
|
44 | 45 | PerfModel,
|
45 | 46 | PlanDebugStats,
|
| 47 | + PlanLoader, |
46 | 48 | PlannerError,
|
47 | 49 | PlannerErrorType,
|
48 | 50 | Proposer,
|
@@ -118,6 +120,60 @@ def to_sharding_plan(
|
118 | 120 | return ShardingPlan(plan)
|
119 | 121 |
|
120 | 122 |
|
def extract_plan(
    search_space: List[ShardingOption],
    loaded_sharding_options: Dict[int, ShardingOption],
) -> List[ShardingOption]:
    """
    Merge shards from a previously loaded plan into the freshly enumerated
    search space.

    For every enumerated ``ShardingOption`` whose storage hash matches an
    entry in ``loaded_sharding_options``, a copy of the option is appended
    with its ``shards`` taken from the loaded plan; all other fields come
    from the enumerated option.

    Args:
        search_space (List[ShardingOption]): sharding options produced by the
            enumerator for the current planner context.
        loaded_sharding_options (Dict[int, ShardingOption]): sharding options
            from a stored plan, keyed by storage hash.

    Returns:
        List[ShardingOption]: merged options, one per matched storage hash.

    Raises:
        PlannerError: if two enumerated options share a storage hash, or if
            the merged search space does not cover every loaded option.
    """
    new_search_space: List[ShardingOption] = []
    seen_hash_set = set()

    for so in search_space:
        # Hoist the hash: it is needed for the duplicate check, the set
        # insert, and the loaded-plan lookup below.
        storage_hash = so.storage_hash()

        # Validate that the storage hash is unique and isn't mapped to multiple sharding options
        if storage_hash in seen_hash_set:
            raise PlannerError(
                error_type=PlannerErrorType.PLAN_LOADING_FAILED,
                message=f"Found a duplicate storage hash {storage_hash} for FQNs {[option.fqn for option in search_space]}\n",
            )
        seen_hash_set.add(storage_hash)

        loaded_so = loaded_sharding_options.get(storage_hash)
        if loaded_so is not None:
            new_search_space.append(
                ShardingOption(
                    name=so.name,
                    tensor=so.tensor,
                    module=so.module,
                    input_lengths=so.input_lengths,
                    batch_size=so.batch_size,
                    compute_kernel=so.compute_kernel,
                    sharding_type=so.sharding_type,
                    partition_by=so.partition_by,
                    # We only need to update the shards from the loaded plan
                    shards=loaded_so.shards,
                    cache_params=so.cache_params,
                    enforce_hbm=so.enforce_hbm,
                    stochastic_rounding=so.stochastic_rounding,
                    bounds_check_mode=so.bounds_check_mode,
                    dependency=so.dependency,
                    is_pooled=so.is_pooled,
                    feature_names=so.feature_names,
                    output_dtype=so.output_dtype,
                    key_value_params=so.key_value_params,
                )
            )

    # Validate that populated search space is the same size as the enumerated search space
    if len(loaded_sharding_options) != len(new_search_space):
        raise PlannerError(
            error_type=PlannerErrorType.PLAN_LOADING_FAILED,
            message=f"Loaded sharding options from Storage, but not all search space is covered. Merged search space len {len(new_search_space)} != loaded Sharding options len {len(loaded_sharding_options)}\n",
        )
    return new_search_space
| 175 | + |
| 176 | + |
121 | 177 | def _merge_plans(best_plans: List[ShardingPlan]) -> ShardingPlan:
|
122 | 178 | if len(best_plans) == 1:
|
123 | 179 | return best_plans[0]
|
    def hash_planner_context_inputs_str(self) -> str:
        """
        Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats.
        These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.

        Returns:
            str: a hash string capturing topology, batch size, enumerator,
                storage reservation, and constraints.
        """
        # Delegates to the module-level helper; stats/partitioner/proposer/perf
        # model are intentionally excluded so a stored plan stays comparable.
        return hash_planner_context_inputs_str(
            self._topology,
            self._batch_size,
            self._enumerator,
            self._storage_reservation,
            self._constraints,
        )
272 | 344 |
|
273 | 345 | class EmbeddingShardingPlanner(EmbeddingPlannerBase):
|
274 | 346 | """
|
@@ -315,6 +387,7 @@ def __init__(
|
315 | 387 | List[Callable[[List[ShardingOption]], List[ShardingOption]]]
|
316 | 388 | ] = None,
|
317 | 389 | timeout_seconds: Optional[int] = None,
|
| 390 | + plan_loader: Optional[PlanLoader] = None, |
318 | 391 | ) -> None:
|
319 | 392 | super().__init__(
|
320 | 393 | topology=topology,
|
@@ -347,6 +420,8 @@ def __init__(
|
347 | 420 | else NoopPerfModel(topology=self._topology)
|
348 | 421 | )
|
349 | 422 |
|
| 423 | + self.plan_loader = plan_loader |
| 424 | + |
350 | 425 | self._num_proposals: int = 0
|
351 | 426 | self._num_plans: int = 0
|
352 | 427 | self._best_plan: Optional[List[ShardingOption]] = None
|
@@ -427,86 +502,113 @@ def plan(
|
427 | 502 | # No shardable parameters
|
428 | 503 | return ShardingPlan({})
|
429 | 504 |
|
430 |
| - proposal_cache: Dict[ |
431 |
| - Tuple[int, ...], |
432 |
| - Tuple[bool, Optional[List[ShardingOption]], Optional[float]], |
433 |
| - ] = {} |
434 |
| - |
435 |
| - for proposer in self._proposers: |
436 |
| - proposer.load(search_space=search_space, enumerator=self._enumerator) |
437 |
| - |
438 |
| - start = time.time() |
439 |
| - for proposer in self._proposers: |
440 |
| - proposal = proposer.propose() |
441 |
| - |
442 |
| - while proposal: |
443 |
| - end = time.time() |
444 |
| - elapsed = end - start |
445 |
| - if self._timeout_seconds: |
446 |
| - if elapsed > self._timeout_seconds: |
447 |
| - logger.info( |
448 |
| - f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s" |
449 |
| - ) |
450 |
| - break |
451 |
| - proposal_key = tuple(sorted(map(hash, proposal))) |
452 |
| - if proposal_key in proposal_cache: |
453 |
| - partitionable, plan, perf_rating = proposal_cache[proposal_key] |
454 |
| - proposer.feedback( |
455 |
| - partitionable=partitionable, |
456 |
| - plan=plan, |
457 |
| - perf_rating=perf_rating, |
458 |
| - storage_constraint=storage_constraint, |
459 |
| - ) |
460 |
| - proposal = proposer.propose() |
461 |
| - continue |
462 |
| - |
463 |
| - self._num_proposals += 1 |
464 |
| - try: |
465 |
| - # plan is just proposal where shard.rank is populated |
466 |
| - plan = self._partitioner.partition( |
467 |
| - proposal=proposal, |
468 |
| - storage_constraint=storage_constraint, |
469 |
| - ) |
470 |
| - self._num_plans += 1 |
471 |
| - perf_rating = self._perf_model.rate(plan=plan) |
472 |
| - if perf_rating < best_perf_rating: |
473 |
| - best_perf_rating = perf_rating |
474 |
| - best_plan = copy.deepcopy(plan) |
475 |
| - proposal_cache[proposal_key] = (True, plan, perf_rating) |
476 |
| - proposer.feedback( |
477 |
| - partitionable=True, |
478 |
| - plan=plan, |
479 |
| - perf_rating=perf_rating, |
480 |
| - storage_constraint=storage_constraint, |
481 |
| - ) |
482 |
| - except PlannerError as planner_error: |
483 |
| - last_planner_error = planner_error |
484 |
| - # shallow copy of the proposal |
485 |
| - last_proposal: List[ShardingOption] = copy.copy(proposal) |
486 |
| - current_storage = cast( |
487 |
| - Storage, |
488 |
| - reduce( |
489 |
| - lambda x, y: x + y, |
490 |
| - [ |
491 |
| - shard.storage |
492 |
| - for option in proposal |
493 |
| - for shard in option.shards |
494 |
| - ], |
495 |
| - ), |
496 |
| - ) |
497 |
| - if current_storage < lowest_storage: |
498 |
| - lowest_storage = current_storage |
499 |
| - proposal_cache[proposal_key] = (False, proposal, None) |
500 |
| - proposer.feedback( |
501 |
| - partitionable=False, |
502 |
| - plan=proposal, |
503 |
| - storage_constraint=storage_constraint, |
504 |
| - ) |
| 505 | + loaded_sharding_options = None |
| 506 | + loaded_best_plan: List[ShardingOption] = [] |
| 507 | + |
| 508 | + if self.plan_loader is not None: |
| 509 | + # validate plan before loading |
| 510 | + self._loader_plan_validation( |
| 511 | + current_planner_hash=self.hash_planner_context_inputs_str(), |
| 512 | + # pyre-fixme[16]: `Optional` has no attribute `plan_context_hash`. |
| 513 | + loaded_plan_hash=self.plan_loader.plan_context_hash(), |
| 514 | + ) |
| 515 | + # pyre-ignore |
| 516 | + loaded_sharding_options = self.plan_loader.load() |
| 517 | + if loaded_sharding_options is not None: |
| 518 | + # Merging sharding options from loaded plan with enumerated search space |
| 519 | + loaded_best_plan = extract_plan( |
| 520 | + search_space=search_space, |
| 521 | + loaded_sharding_options=loaded_sharding_options, |
| 522 | + ) |
| 523 | + |
| 524 | +        # Loaded plan is validated successfully and can be used to generate the sharding plan, skipping new plan generation. |
| 525 | + if loaded_best_plan: |
| 526 | + logger.info( |
| 527 | + # pyre-ignore |
| 528 | + f"Loded sharding options from Storage with plan id: {self.plan_loader.get_plan_id()} skipping new plan generation" |
| 529 | + ) |
| 530 | + best_plan = copy.deepcopy(loaded_best_plan) |
| 531 | + else: |
| 532 | + proposal_cache: Dict[ |
| 533 | + Tuple[int, ...], |
| 534 | + Tuple[bool, Optional[List[ShardingOption]], Optional[float]], |
| 535 | + ] = {} |
| 536 | + |
| 537 | + for proposer in self._proposers: |
| 538 | + proposer.load(search_space=search_space, enumerator=self._enumerator) |
505 | 539 |
|
506 |
| - # clear shard.rank for each sharding_option |
507 |
| - reset_shard_rank(proposal) |
| 540 | + start = time.time() |
| 541 | + for proposer in self._proposers: |
508 | 542 | proposal = proposer.propose()
|
509 | 543 |
|
| 544 | + while proposal: |
| 545 | + end = time.time() |
| 546 | + elapsed = end - start |
| 547 | + if self._timeout_seconds: |
| 548 | + if elapsed > self._timeout_seconds: |
| 549 | + logger.info( |
| 550 | + f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s" |
| 551 | + ) |
| 552 | + break |
| 553 | + proposal_key = tuple(sorted(map(hash, proposal))) |
| 554 | + if proposal_key in proposal_cache: |
| 555 | + partitionable, plan, perf_rating = proposal_cache[proposal_key] |
| 556 | + proposer.feedback( |
| 557 | + partitionable=partitionable, |
| 558 | + plan=plan, |
| 559 | + perf_rating=perf_rating, |
| 560 | + storage_constraint=storage_constraint, |
| 561 | + ) |
| 562 | + proposal = proposer.propose() |
| 563 | + continue |
| 564 | + |
| 565 | + self._num_proposals += 1 |
| 566 | + try: |
| 567 | + # plan is just proposal where shard.rank is populated |
| 568 | + plan = self._partitioner.partition( |
| 569 | + proposal=proposal, |
| 570 | + storage_constraint=storage_constraint, |
| 571 | + ) |
| 572 | + self._num_plans += 1 |
| 573 | + perf_rating = self._perf_model.rate(plan=plan) |
| 574 | + if perf_rating < best_perf_rating: |
| 575 | + best_perf_rating = perf_rating |
| 576 | + best_plan = copy.deepcopy(plan) |
| 577 | + proposal_cache[proposal_key] = (True, plan, perf_rating) |
| 578 | + proposer.feedback( |
| 579 | + partitionable=True, |
| 580 | + plan=plan, |
| 581 | + perf_rating=perf_rating, |
| 582 | + storage_constraint=storage_constraint, |
| 583 | + ) |
| 584 | + except PlannerError as planner_error: |
| 585 | + last_planner_error = planner_error |
| 586 | + # shallow copy of the proposal |
| 587 | + last_proposal: List[ShardingOption] = copy.copy(proposal) |
| 588 | + current_storage = cast( |
| 589 | + Storage, |
| 590 | + reduce( |
| 591 | + lambda x, y: x + y, |
| 592 | + [ |
| 593 | + shard.storage |
| 594 | + for option in proposal |
| 595 | + for shard in option.shards |
| 596 | + ], |
| 597 | + ), |
| 598 | + ) |
| 599 | + if current_storage < lowest_storage: |
| 600 | + lowest_storage = current_storage |
| 601 | + proposal_cache[proposal_key] = (False, proposal, None) |
| 602 | + proposer.feedback( |
| 603 | + partitionable=False, |
| 604 | + plan=proposal, |
| 605 | + storage_constraint=storage_constraint, |
| 606 | + ) |
| 607 | + |
| 608 | + # clear shard.rank for each sharding_option |
| 609 | + reset_shard_rank(proposal) |
| 610 | + proposal = proposer.propose() |
| 611 | + |
510 | 612 | if best_plan:
|
511 | 613 | for callback in self._callbacks:
|
512 | 614 | best_plan = callback(best_plan)
|
@@ -607,6 +709,32 @@ def plan(
|
607 | 709 | + last_planner_error_info,
|
608 | 710 | )
|
609 | 711 |
|
| 712 | + def _loader_plan_validation( |
| 713 | + self, current_planner_hash: str, loaded_plan_hash: Optional[str] |
| 714 | + ) -> None: |
| 715 | + """ |
| 716 | + Validates that the current planner context hash matches the loaded plan context hash. |
| 717 | +
|
| 718 | + Args: |
| 719 | + current_planner_hash (str): Hash from current planner context |
| 720 | + loaded_plan_hash (Optional[str]): Hash from loaded plan context |
| 721 | +
|
| 722 | + Raises: |
| 723 | + PlannerError: If hashes don't match |
| 724 | + """ |
| 725 | + if loaded_plan_hash is not None and current_planner_hash != loaded_plan_hash: |
| 726 | + # pyre-fixme[16]: `Optional` has no attribute `get_plan_id`. |
| 727 | + plan_id = self.plan_loader.get_plan_id() if self.plan_loader else None |
| 728 | + error_msg = ( |
| 729 | + f"Planner input context mismatch detected for {plan_id} and current planner set up:" |
| 730 | + f"\nCurrent planner hash: {current_planner_hash}, Loaded plan hash: {loaded_plan_hash}" |
| 731 | + ) |
| 732 | + raise PlannerError( |
| 733 | + error_type=PlannerErrorType.PLANNER_INPUT_CONTEXT_MISMATCH, |
| 734 | + message="Unable to load, because of planner input mismatch - cannot validate this plan is the best plan for current context.. \n" |
| 735 | + + error_msg, |
| 736 | + ) |
| 737 | + |
610 | 738 |
|
611 | 739 | class HeteroEmbeddingShardingPlanner(ShardingPlanner):
|
612 | 740 | """
|
|
0 commit comments