Skip to content

Commit

Permalink
Use sharding propagation when possible to obtain a default solution to compare with the auto-sharding solution.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 565817073
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Sep 16, 2023
1 parent 27bfb08 commit 83d9984
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 28 deletions.
67 changes: 56 additions & 11 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2316,7 +2316,9 @@ AutoShardingSolverResult CallSolver(
const std::vector<NodeStrategyIdx>& s_hint,
int64_t memory_budget_per_device, bool crash_at_infinity_costs_check,
bool compute_iis, int64_t solver_timeout_in_seconds,
bool allow_alias_to_follower_conversion) {
bool allow_alias_to_follower_conversion,
const absl::flat_hash_map<std::string, const HloInstruction*>&
sharding_propagation_solution) {
// Serialize edges and edge costs to 1d numpy arrays
AutoShardingSolverRequest request;
request.num_nodes = leaf_strategies.size();
Expand Down Expand Up @@ -2346,10 +2348,16 @@ AutoShardingSolverResult CallSolver(
// Serialize node costs
for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) {
const StrategyVector* strategies = leaf_strategies[node_idx];
auto instruction_name = instructions.at(strategies->instruction_id)->name();
request.instruction_names.push_back(
absl::StrCat(instructions.at(strategies->instruction_id)->name(),
" (id: ", node_idx, ")"));
absl::StrCat(instruction_name, " (id: ", node_idx, ")"));
std::vector<double> ci, di, mi, pi;
auto default_strategy = HloSharding::Replicate();
auto iter = sharding_propagation_solution.find(instruction_name);
if (iter != sharding_propagation_solution.end()) {
CHECK(iter->second->has_sharding()) << iter->second->ToString();
default_strategy = iter->second->sharding();
}
for (NodeStrategyIdx j = 0; j < strategies->leaf_vector.size(); ++j) {
const ShardingStrategy& strategy = strategies->leaf_vector[j];
const HloSharding& sharding = strategy.output_sharding;
Expand All @@ -2359,7 +2367,7 @@ AutoShardingSolverResult CallSolver(
mi.push_back(strategy.memory_cost);
// TODO(moffitt): Revisit the default strategy below, which is currently
// defined as the "trivial sharding" in hlo_sharding.h
pi.push_back(sharding.IsReplicated() && !sharding.IsManual() ? 0.0 : 1.0);
pi.push_back(sharding == default_strategy ? 0.0 : 1.0);
}
request.c.push_back(ci);
request.d.push_back(di);
Expand Down Expand Up @@ -3974,7 +3982,9 @@ AutoShardingImplementation::AutoShardingImplementation(
StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
HloModule* module,
const absl::flat_hash_set<std::string>& replicated_small_tensors,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
const absl::flat_hash_set<absl::string_view>& execution_threads,
const absl::flat_hash_map<std::string, const HloInstruction*>&
sharding_propagation_solution) {
if (!option_.enable) {
return AutoShardingResult::kModuleUnchanged;
}
Expand Down Expand Up @@ -4226,7 +4236,7 @@ StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
if (!solver_option.load_solution_vector) {
auto solver_result =
Solve(*hlo_live_range, liveness_set, strategy_map, leaf_strategies,
cost_graph, alias_set, option_);
cost_graph, alias_set, option_, sharding_propagation_solution);
if (solver_result.skip_auto_sharding) {
return AutoShardingResult::kModuleUnchangedNoShardingPerfomed;
} else if (!solver_result.status.ok()) {
Expand Down Expand Up @@ -4317,6 +4327,13 @@ bool IsModuleManuallySharded(const HloModule* module) {
return false;
}

std::unique_ptr<HloModule> CloneModule(const HloModule* module) {
auto module_clone = module->Clone("");
module_clone->set_layout_canonicalization_callback(
module->layout_canonicalization_callback());
return module_clone;
}

StatusOr<bool> AutoSharding::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
Expand Down Expand Up @@ -4383,6 +4400,35 @@ StatusOr<bool> AutoSharding::Run(
mesh_shapes.push_back(option_.device_mesh_shape);
}

absl::flat_hash_map<std::string, const HloInstruction*>
sharding_propagation_solution;
std::unique_ptr<HloModule> module_with_default_solution = nullptr;
if (option_.use_sharding_propagation_for_default_shardings) {
module_with_default_solution = CloneModule(module);
// TODO(pratikf): Ensure that we're passing the correct custom call sharding
// helper to the sharding propagation pass.
auto sharding_prop = ShardingPropagation(
/*is_spmd */ true, /*propagate_metadata */ false,
/*allow_spmd_sharding_propagation_to_output*/
module->config().allow_spmd_sharding_propagation_to_output(),
/*allow_spmd_sharding_propagation_to_parameters */
absl::InlinedVector<bool, 1>{false},
/*cse_prevention_only */ false,
/*sharding_helper*/ nullptr);

CHECK_OK(sharding_prop.Run(module_with_default_solution.get(),
execution_threads));
LOG(INFO) << module_with_default_solution->ToString();
for (const auto computation :
module_with_default_solution->computations()) {
for (const auto instruction : computation->instructions()) {
if (instruction->has_sharding()) {
sharding_propagation_solution[instruction->name()] = instruction;
}
}
}
}

size_t num_meshes = mesh_shapes.size();
std::vector<std::unique_ptr<HloModule>> modules(num_meshes);
std::vector<StatusOr<AutoShardingResult>> changed(
Expand All @@ -4399,11 +4445,10 @@ StatusOr<bool> AutoSharding::Run(
AutoShardingOption this_option = option_;
this_option.device_mesh_shape = mesh_shapes[i];
auto pass = new AutoShardingImplementation(this_option);
auto module_clone = module->Clone("");
module_clone->set_layout_canonicalization_callback(
module->layout_canonicalization_callback());
auto pass_result = pass->RunAutoSharding(
module_clone.get(), replicated_small_tensors, execution_threads);
auto module_clone = CloneModule(module);
auto pass_result =
pass->RunAutoSharding(module_clone.get(), replicated_small_tensors,
execution_threads, sharding_propagation_solution);

changed[i] = pass_result;
objective_values[i] = pass->GetSolverOptimalObjectiveValue();
Expand Down
18 changes: 10 additions & 8 deletions xla/hlo/experimental/auto_sharding/auto_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ class AutoShardingImplementation {
StatusOr<AutoShardingResult> RunAutoSharding(
HloModule* module,
const absl::flat_hash_set<std::string>& replicated_small_tensors,
const absl::flat_hash_set<absl::string_view>& execution_threads);
const absl::flat_hash_set<absl::string_view>& execution_threads,
const absl::flat_hash_map<std::string, const HloInstruction*>&
sharding_propagation_solution = {});

// Removes SPMD annotations (if there are) to test AutoSharding on manually
// annotated graphs.
Expand Down Expand Up @@ -210,13 +212,13 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
const ClusterEnvironment& cluster_env);

// The high-level "recipe" for solving an Auto Sharding problem.
AutoShardingSolverResult Solve(const HloLiveRange& hlo_live_range,
const LivenessSet& liveness_set,
const StrategyMap& strategy_map,
const LeafStrategies& leaf_strategies,
const CostGraph& cost_graph,
const AliasSet& alias_set,
const AutoShardingOption& option);
AutoShardingSolverResult Solve(
const HloLiveRange& hlo_live_range, const LivenessSet& liveness_set,
const StrategyMap& strategy_map, const LeafStrategies& leaf_strategies,
const CostGraph& cost_graph, const AliasSet& alias_set,
const AutoShardingOption& option,
const absl::flat_hash_map<std::string, const HloInstruction*>&
sharding_propagation_solution = {});

// Populates temporal distance values.
void PopulateTemporalValues(const CostGraph& cost_graph,
Expand Down
16 changes: 8 additions & 8 deletions xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ limitations under the License.
namespace xla {
namespace spmd {

AutoShardingSolverResult Solve(const HloLiveRange& hlo_live_range,
const LivenessSet& liveness_set,
const StrategyMap& strategy_map,
const LeafStrategies& leaf_strategies,
const CostGraph& cost_graph,
const AliasSet& alias_set,
const AutoShardingOption& option) {
// Thin wrapper that forwards the auto-sharding problem to CallSolver with a
// fixed set of solver flags derived from `option`. The (possibly empty)
// `sharding_propagation_solution` maps instruction names to instructions whose
// shardings serve as the per-node default strategies inside the solver.
AutoShardingSolverResult Solve(
    const HloLiveRange& hlo_live_range, const LivenessSet& liveness_set,
    const StrategyMap& strategy_map, const LeafStrategies& leaf_strategies,
    const CostGraph& cost_graph, const AliasSet& alias_set,
    const AutoShardingOption& option,
    const absl::flat_hash_map<std::string, const HloInstruction*>&
        sharding_propagation_solution) {
  // No strategy hint is supplied here; crash-on-infinite-cost is disabled when
  // multiple mesh shapes are being tried, since some meshes may be infeasible.
  return CallSolver(
      hlo_live_range, liveness_set, strategy_map, leaf_strategies, cost_graph,
      alias_set, /*s_hint*/ {}, option.memory_budget_per_device,
      /*crash_at_infinity_costs_check*/ !option.try_multiple_mesh_shapes,
      /*compute_iis*/ true, option.solver_timeout_in_seconds,
      option.allow_alias_to_follower_conversion, sharding_propagation_solution);
}

void PopulateTemporalValues(const CostGraph& cost_graph,
Expand Down
5 changes: 5 additions & 0 deletions xla/hlo/experimental/auto_sharding/auto_sharding_option.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ struct AutoShardingOption {
// sharding.
int64_t small_tensor_byte_size = 0;

// In order to obtain default sharding strategies for instructions to limit
// departures from the defaults, use sharding propagation instead of assuming
// a simple replicated default.
bool use_sharding_propagation_for_default_shardings = true;

std::string ToString() {
std::vector<std::string> lines;
lines.push_back(absl::StrCat("preserve_shardings: ", preserve_shardings));
Expand Down
4 changes: 3 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ AutoShardingSolverResult CallSolver(
const std::vector<NodeStrategyIdx>& s_hint,
int64_t memory_budget_per_device, bool crash_at_infinity_costs_check,
bool compute_iis, int64_t solver_timeout_in_seconds,
bool allow_alias_to_follower_conversion);
bool allow_alias_to_follower_conversion,
const absl::flat_hash_map<std::string, const HloInstruction*>&
sharding_propagation_solution = {});

} // namespace spmd
} // namespace xla
Expand Down

0 comments on commit 83d9984

Please sign in to comment.