Skip to content

Commit

Permalink
Allow serializing env_options_overrides separately from the rest of
Browse files Browse the repository at this point in the history
CompileOptions.

PiperOrigin-RevId: 565813066
  • Loading branch information
pschuh authored and copybara-github committed Sep 16, 2023
1 parent 49314f4 commit 27bfb08
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 30 deletions.
1 change: 1 addition & 0 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:protobuf",
"@tsl//tsl/platform:statusor",
],
)
Expand Down
76 changes: 46 additions & 30 deletions xla/pjrt/pjrt_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ StatusOr<CompileOptionsProto> CompileOptions::ToProto() const {
return output;
}

void CompileOptions::SerializeEnvOptionOverrides(
google::protobuf::Map<std::string, xla::OptionOverrideProto>*
output_env_option_overrides) const {
for (auto& env_option_override : env_option_overrides) {
auto& tmp = (*output_env_option_overrides)[env_option_override.first];
std::visit([&](const auto& arg) { SetOptionOverride(tmp, arg); },
env_option_override.second);
}
}

StatusOr<CompileOptions> CompileOptions::FromProto(
const CompileOptionsProto& proto) {
if (!proto.serialized_multi_slice_config().empty()) {
Expand All @@ -112,36 +122,8 @@ StatusOr<CompileOptions> CompileOptions::FromProto(
output.executable_build_options = executable_build_options;
output.compile_portable_executable = proto.compile_portable_executable();
output.profile_version = proto.profile_version();
for (auto& env_option_override : proto.env_option_overrides()) {
switch (env_option_override.second.value_case()) {
case OptionOverrideProto::kStringField:
output.env_option_overrides.push_back(
{env_option_override.first,
CompileOptions::OptionOverride(
env_option_override.second.string_field())});
break;
case OptionOverrideProto::kBoolField:
output.env_option_overrides.push_back(
{env_option_override.first,
CompileOptions::OptionOverride(
env_option_override.second.bool_field())});
break;
case OptionOverrideProto::kIntField:
output.env_option_overrides.push_back(
{env_option_override.first,
CompileOptions::OptionOverride(
env_option_override.second.int_field())});
break;
case OptionOverrideProto::kDoubleField:
output.env_option_overrides.push_back(
{env_option_override.first,
CompileOptions::OptionOverride(
env_option_override.second.double_field())});
break;
case OptionOverrideProto::VALUE_NOT_SET:
return InternalError("OptionOverrideProto value not set.");
}
}
TF_ASSIGN_OR_RETURN(output.env_option_overrides,
LoadEnvOptionOverrides(proto.env_option_overrides()));
return output;
}

Expand Down Expand Up @@ -369,6 +351,40 @@ PjRtExecutableUtil::RunHloCostAnalysis(
return ret;
}

StatusOr<std::vector<std::pair<std::string, CompileOptions::OptionOverride>>>
CompileOptions::LoadEnvOptionOverrides(
const google::protobuf::Map<std::string, xla::OptionOverrideProto>&
env_option_overrides) {
std::vector<std::pair<std::string, CompileOptions::OptionOverride>> result;
for (auto& env_option_override : env_option_overrides) {
switch (env_option_override.second.value_case()) {
case OptionOverrideProto::kStringField:
result.push_back({env_option_override.first,
CompileOptions::OptionOverride(
env_option_override.second.string_field())});
break;
case OptionOverrideProto::kBoolField:
result.push_back({env_option_override.first,
CompileOptions::OptionOverride(
env_option_override.second.bool_field())});
break;
case OptionOverrideProto::kIntField:
result.push_back({env_option_override.first,
CompileOptions::OptionOverride(
env_option_override.second.int_field())});
break;
case OptionOverrideProto::kDoubleField:
result.push_back({env_option_override.first,
CompileOptions::OptionOverride(
env_option_override.second.double_field())});
break;
case OptionOverrideProto::VALUE_NOT_SET:
return InternalError("OptionOverrideProto value not set.");
}
}
return result;
}

Status CompileOptions::ApplyOption(const std::string& key,
const OptionOverride& value) {
if (auto* xla_field = xla::DebugOptions::descriptor()->FindFieldByName(key)) {
Expand Down
11 changes: 11 additions & 0 deletions xla/pjrt/pjrt_executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ limitations under the License.
#include "xla/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/protobuf.h"

namespace xla {

Expand Down Expand Up @@ -112,6 +113,16 @@ struct CompileOptions {
Status ApplyOptionFromString(const tsl::protobuf::FieldDescriptor* field,
const std::string& value);

static StatusOr<
std::vector<std::pair<std::string, CompileOptions::OptionOverride>>>
LoadEnvOptionOverrides(
const google::protobuf::Map<std::string, xla::OptionOverrideProto>&
env_option_overrides);

void SerializeEnvOptionOverrides(
google::protobuf::Map<std::string, xla::OptionOverrideProto>*
output_env_option_overrides) const;

// Serialize the CompileOptions into a CompileOptionsProto.
StatusOr<CompileOptionsProto> ToProto() const;

Expand Down

0 comments on commit 27bfb08

Please sign in to comment.