Skip to content

Commit ff9a03d

Browse files
aliafzalfacebook-github-bot
authored andcommitted
Integrating PlanLoader within OSS Planner (#3376)
Summary: Integrate PlanLoader functionality within the EmbeddingShardingPlanner to enable loading and reusing pre-computed sharding plans. This integration extends the OSS planner with plan loading capabilities. This diff includes: * PlanLoader Integration in EmbeddingShardingPlanner: - Added optional `plan_loader` parameter to EmbeddingShardingPlanner constructor - Integrated plan validation using context hash comparison to ensure loaded plans are compatible with current planner configuration - Fallback to normal planning when plan loader returns null * Plan Loading Workflow:Check if loaded plan context hash matches current planner context * If mismatch detected → raise PlannerError * If validation passes → load sharding options from storage * Map loaded sharding options to current search space using storage_hash * Skip planning phase and use pre-computed plan if available * Search Space Reconstruction: * Mapping of loaded sharding options to enumerated search space * Preserving all original ShardingOption metadata while replacing shard assignments Differential Revision: D81279558
1 parent e24b7fd commit ff9a03d

File tree

3 files changed

+616
-80
lines changed

3 files changed

+616
-80
lines changed

torchrec/distributed/planner/planners.py

Lines changed: 205 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,12 @@
3939
from torchrec.distributed.planner.types import (
4040
Enumerator,
4141
hash_planner_context_inputs,
42+
hash_planner_context_inputs_str,
4243
ParameterConstraints,
4344
Partitioner,
4445
PerfModel,
4546
PlanDebugStats,
47+
PlanLoader,
4648
PlannerError,
4749
PlannerErrorType,
4850
Proposer,
@@ -118,6 +120,60 @@ def to_sharding_plan(
118120
return ShardingPlan(plan)
119121

120122

123+
def extract_plan(
124+
search_space: List[ShardingOption],
125+
loaded_sharding_options: Dict[int, ShardingOption],
126+
) -> List[ShardingOption]:
127+
128+
new_search_space: List[ShardingOption] = []
129+
seen_hash_set = set()
130+
131+
for so in search_space:
132+
133+
# Validate that the storage hash is unique and isn't mapped to multiple sharding options
134+
if so.storage_hash() in seen_hash_set:
135+
raise PlannerError(
136+
error_type=PlannerErrorType.PLAN_LOADING_FAILED,
137+
message=f"Found a duplicate storage hash {so.storage_hash()} for FQNs {[so.fqn for so in search_space]}\n",
138+
)
139+
else:
140+
seen_hash_set.add(so.storage_hash())
141+
142+
loaded_so = loaded_sharding_options.get(so.storage_hash())
143+
if loaded_so is not None:
144+
new_search_space.append(
145+
ShardingOption(
146+
name=so.name,
147+
tensor=so.tensor,
148+
module=so.module,
149+
input_lengths=so.input_lengths,
150+
batch_size=so.batch_size,
151+
compute_kernel=so.compute_kernel,
152+
sharding_type=so.sharding_type,
153+
partition_by=so.partition_by,
154+
# We only need to update the shards from the loaded plan
155+
shards=loaded_so.shards,
156+
cache_params=so.cache_params,
157+
enforce_hbm=so.enforce_hbm,
158+
stochastic_rounding=so.stochastic_rounding,
159+
bounds_check_mode=so.bounds_check_mode,
160+
dependency=so.dependency,
161+
is_pooled=so.is_pooled,
162+
feature_names=so.feature_names,
163+
output_dtype=so.output_dtype,
164+
key_value_params=so.key_value_params,
165+
)
166+
)
167+
168+
# Validate that populated search space is the same size as the enumerated search space
169+
if len(loaded_sharding_options) != len(new_search_space):
170+
raise PlannerError(
171+
error_type=PlannerErrorType.PLAN_LOADING_FAILED,
172+
message=f"Loaded sharding options from Storage, but not all search space is covered. Merged search space len {len(new_search_space)} != loaded Sharding options len {len(loaded_sharding_options)}\n",
173+
)
174+
return new_search_space
175+
176+
121177
def _merge_plans(best_plans: List[ShardingPlan]) -> ShardingPlan:
122178
if len(best_plans) == 1:
123179
return best_plans[0]
@@ -269,6 +325,22 @@ def hash_planner_context_inputs(self) -> int:
269325
self._constraints,
270326
)
271327

328+
def hash_planner_context_inputs_str(self) -> str:
329+
"""
330+
Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats.
331+
These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.
332+
333+
Returns:
334+
Generates a hash capturing topology, batch size, enumerator, storage reservation, stats and constraints.
335+
"""
336+
return hash_planner_context_inputs_str(
337+
self._topology,
338+
self._batch_size,
339+
self._enumerator,
340+
self._storage_reservation,
341+
self._constraints,
342+
)
343+
272344

273345
class EmbeddingShardingPlanner(EmbeddingPlannerBase):
274346
"""
@@ -315,6 +387,7 @@ def __init__(
315387
List[Callable[[List[ShardingOption]], List[ShardingOption]]]
316388
] = None,
317389
timeout_seconds: Optional[int] = None,
390+
plan_loader: Optional[PlanLoader] = None,
318391
) -> None:
319392
super().__init__(
320393
topology=topology,
@@ -347,6 +420,8 @@ def __init__(
347420
else NoopPerfModel(topology=self._topology)
348421
)
349422

423+
self.plan_loader = plan_loader
424+
350425
self._num_proposals: int = 0
351426
self._num_plans: int = 0
352427
self._best_plan: Optional[List[ShardingOption]] = None
@@ -427,86 +502,113 @@ def plan(
427502
# No shardable parameters
428503
return ShardingPlan({})
429504

430-
proposal_cache: Dict[
431-
Tuple[int, ...],
432-
Tuple[bool, Optional[List[ShardingOption]], Optional[float]],
433-
] = {}
434-
435-
for proposer in self._proposers:
436-
proposer.load(search_space=search_space, enumerator=self._enumerator)
437-
438-
start = time.time()
439-
for proposer in self._proposers:
440-
proposal = proposer.propose()
441-
442-
while proposal:
443-
end = time.time()
444-
elapsed = end - start
445-
if self._timeout_seconds:
446-
if elapsed > self._timeout_seconds:
447-
logger.info(
448-
f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s"
449-
)
450-
break
451-
proposal_key = tuple(sorted(map(hash, proposal)))
452-
if proposal_key in proposal_cache:
453-
partitionable, plan, perf_rating = proposal_cache[proposal_key]
454-
proposer.feedback(
455-
partitionable=partitionable,
456-
plan=plan,
457-
perf_rating=perf_rating,
458-
storage_constraint=storage_constraint,
459-
)
460-
proposal = proposer.propose()
461-
continue
462-
463-
self._num_proposals += 1
464-
try:
465-
# plan is just proposal where shard.rank is populated
466-
plan = self._partitioner.partition(
467-
proposal=proposal,
468-
storage_constraint=storage_constraint,
469-
)
470-
self._num_plans += 1
471-
perf_rating = self._perf_model.rate(plan=plan)
472-
if perf_rating < best_perf_rating:
473-
best_perf_rating = perf_rating
474-
best_plan = copy.deepcopy(plan)
475-
proposal_cache[proposal_key] = (True, plan, perf_rating)
476-
proposer.feedback(
477-
partitionable=True,
478-
plan=plan,
479-
perf_rating=perf_rating,
480-
storage_constraint=storage_constraint,
481-
)
482-
except PlannerError as planner_error:
483-
last_planner_error = planner_error
484-
# shallow copy of the proposal
485-
last_proposal: List[ShardingOption] = copy.copy(proposal)
486-
current_storage = cast(
487-
Storage,
488-
reduce(
489-
lambda x, y: x + y,
490-
[
491-
shard.storage
492-
for option in proposal
493-
for shard in option.shards
494-
],
495-
),
496-
)
497-
if current_storage < lowest_storage:
498-
lowest_storage = current_storage
499-
proposal_cache[proposal_key] = (False, proposal, None)
500-
proposer.feedback(
501-
partitionable=False,
502-
plan=proposal,
503-
storage_constraint=storage_constraint,
504-
)
505+
loaded_sharding_options = None
506+
loaded_best_plan: List[ShardingOption] = []
507+
508+
if self.plan_loader is not None:
509+
# validate plan before loading
510+
self._loader_plan_validation(
511+
current_planner_hash=self.hash_planner_context_inputs_str(),
512+
# pyre-fixme[16]: `Optional` has no attribute `plan_context_hash`.
513+
loaded_plan_hash=self.plan_loader.plan_context_hash(),
514+
)
515+
# pyre-ignore
516+
loaded_sharding_options = self.plan_loader.load()
517+
if loaded_sharding_options is not None:
518+
# Merging sharding options from loaded plan with enumerated search space
519+
loaded_best_plan = extract_plan(
520+
search_space=search_space,
521+
loaded_sharding_options=loaded_sharding_options,
522+
)
523+
524+
# Loaded plan is validated successfully and can be used for generate the sharding plan, skipping new plan generation.
525+
if loaded_best_plan:
526+
logger.info(
527+
# pyre-ignore
528+
f"Loded sharding options from Storage with plan id: {self.plan_loader.get_plan_id()} skipping new plan generation"
529+
)
530+
best_plan = copy.deepcopy(loaded_best_plan)
531+
else:
532+
proposal_cache: Dict[
533+
Tuple[int, ...],
534+
Tuple[bool, Optional[List[ShardingOption]], Optional[float]],
535+
] = {}
536+
537+
for proposer in self._proposers:
538+
proposer.load(search_space=search_space, enumerator=self._enumerator)
505539

506-
# clear shard.rank for each sharding_option
507-
reset_shard_rank(proposal)
540+
start = time.time()
541+
for proposer in self._proposers:
508542
proposal = proposer.propose()
509543

544+
while proposal:
545+
end = time.time()
546+
elapsed = end - start
547+
if self._timeout_seconds:
548+
if elapsed > self._timeout_seconds:
549+
logger.info(
550+
f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s"
551+
)
552+
break
553+
proposal_key = tuple(sorted(map(hash, proposal)))
554+
if proposal_key in proposal_cache:
555+
partitionable, plan, perf_rating = proposal_cache[proposal_key]
556+
proposer.feedback(
557+
partitionable=partitionable,
558+
plan=plan,
559+
perf_rating=perf_rating,
560+
storage_constraint=storage_constraint,
561+
)
562+
proposal = proposer.propose()
563+
continue
564+
565+
self._num_proposals += 1
566+
try:
567+
# plan is just proposal where shard.rank is populated
568+
plan = self._partitioner.partition(
569+
proposal=proposal,
570+
storage_constraint=storage_constraint,
571+
)
572+
self._num_plans += 1
573+
perf_rating = self._perf_model.rate(plan=plan)
574+
if perf_rating < best_perf_rating:
575+
best_perf_rating = perf_rating
576+
best_plan = copy.deepcopy(plan)
577+
proposal_cache[proposal_key] = (True, plan, perf_rating)
578+
proposer.feedback(
579+
partitionable=True,
580+
plan=plan,
581+
perf_rating=perf_rating,
582+
storage_constraint=storage_constraint,
583+
)
584+
except PlannerError as planner_error:
585+
last_planner_error = planner_error
586+
# shallow copy of the proposal
587+
last_proposal: List[ShardingOption] = copy.copy(proposal)
588+
current_storage = cast(
589+
Storage,
590+
reduce(
591+
lambda x, y: x + y,
592+
[
593+
shard.storage
594+
for option in proposal
595+
for shard in option.shards
596+
],
597+
),
598+
)
599+
if current_storage < lowest_storage:
600+
lowest_storage = current_storage
601+
proposal_cache[proposal_key] = (False, proposal, None)
602+
proposer.feedback(
603+
partitionable=False,
604+
plan=proposal,
605+
storage_constraint=storage_constraint,
606+
)
607+
608+
# clear shard.rank for each sharding_option
609+
reset_shard_rank(proposal)
610+
proposal = proposer.propose()
611+
510612
if best_plan:
511613
for callback in self._callbacks:
512614
best_plan = callback(best_plan)
@@ -607,6 +709,32 @@ def plan(
607709
+ last_planner_error_info,
608710
)
609711

712+
def _loader_plan_validation(
713+
self, current_planner_hash: str, loaded_plan_hash: Optional[str]
714+
) -> None:
715+
"""
716+
Validates that the current planner context hash matches the loaded plan context hash.
717+
718+
Args:
719+
current_planner_hash (str): Hash from current planner context
720+
loaded_plan_hash (Optional[str]): Hash from loaded plan context
721+
722+
Raises:
723+
PlannerError: If hashes don't match
724+
"""
725+
if loaded_plan_hash is not None and current_planner_hash != loaded_plan_hash:
726+
# pyre-fixme[16]: `Optional` has no attribute `get_plan_id`.
727+
plan_id = self.plan_loader.get_plan_id() if self.plan_loader else None
728+
error_msg = (
729+
f"Planner input context mismatch detected for {plan_id} and current planner set up:"
730+
f"\nCurrent planner hash: {current_planner_hash}, Loaded plan hash: {loaded_plan_hash}"
731+
)
732+
raise PlannerError(
733+
error_type=PlannerErrorType.PLANNER_INPUT_CONTEXT_MISMATCH,
734+
message="Unable to load, because of planner input mismatch - cannot validate this plan is the best plan for current context.. \n"
735+
+ error_msg,
736+
)
737+
610738

611739
class HeteroEmbeddingShardingPlanner(ShardingPlanner):
612740
"""

0 commit comments

Comments
 (0)