diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index 4495166dc187e..32ac8eb561f13 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -254,6 +254,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/pjrt/pjrt_executable.cc b/xla/pjrt/pjrt_executable.cc index f9ed15b43aeb2..d58ad99b0377d 100644 --- a/xla/pjrt/pjrt_executable.cc +++ b/xla/pjrt/pjrt_executable.cc @@ -89,6 +89,16 @@ StatusOr CompileOptions::ToProto() const { return output; } +void CompileOptions::SerializeEnvOptionOverrides( + google::protobuf::Map* + output_env_option_overrides) const { + for (auto& env_option_override : env_option_overrides) { + auto& tmp = (*output_env_option_overrides)[env_option_override.first]; + std::visit([&](const auto& arg) { SetOptionOverride(tmp, arg); }, + env_option_override.second); + } +} + StatusOr CompileOptions::FromProto( const CompileOptionsProto& proto) { if (!proto.serialized_multi_slice_config().empty()) { @@ -112,36 +122,8 @@ StatusOr CompileOptions::FromProto( output.executable_build_options = executable_build_options; output.compile_portable_executable = proto.compile_portable_executable(); output.profile_version = proto.profile_version(); - for (auto& env_option_override : proto.env_option_overrides()) { - switch (env_option_override.second.value_case()) { - case OptionOverrideProto::kStringField: - output.env_option_overrides.push_back( - {env_option_override.first, - CompileOptions::OptionOverride( - env_option_override.second.string_field())}); - break; - case OptionOverrideProto::kBoolField: - output.env_option_overrides.push_back( - {env_option_override.first, - CompileOptions::OptionOverride( - env_option_override.second.bool_field())}); - break; - case OptionOverrideProto::kIntField: - output.env_option_overrides.push_back( - {env_option_override.first, - CompileOptions::OptionOverride( - env_option_override.second.int_field())}); - break; - case OptionOverrideProto::kDoubleField: - output.env_option_overrides.push_back( - {env_option_override.first, - CompileOptions::OptionOverride( - env_option_override.second.double_field())}); - break; - case OptionOverrideProto::VALUE_NOT_SET: - return InternalError("OptionOverrideProto value not set."); - } - } + TF_ASSIGN_OR_RETURN(output.env_option_overrides, + LoadEnvOptionOverrides(proto.env_option_overrides())); return output; } @@ -369,6 +351,40 @@ PjRtExecutableUtil::RunHloCostAnalysis( return ret; } +StatusOr>> +CompileOptions::LoadEnvOptionOverrides( + const google::protobuf::Map& + env_option_overrides) { + std::vector> result; + for (auto& env_option_override : env_option_overrides) { + switch (env_option_override.second.value_case()) { + case OptionOverrideProto::kStringField: + result.push_back({env_option_override.first, + CompileOptions::OptionOverride( + env_option_override.second.string_field())}); + break; + case OptionOverrideProto::kBoolField: + result.push_back({env_option_override.first, + CompileOptions::OptionOverride( + env_option_override.second.bool_field())}); + break; + case OptionOverrideProto::kIntField: + result.push_back({env_option_override.first, + CompileOptions::OptionOverride( + env_option_override.second.int_field())}); + break; + case OptionOverrideProto::kDoubleField: + result.push_back({env_option_override.first, + CompileOptions::OptionOverride( + env_option_override.second.double_field())}); + break; + case OptionOverrideProto::VALUE_NOT_SET: + return InternalError("OptionOverrideProto value not set."); + } + } + return result; +} + Status CompileOptions::ApplyOption(const std::string& key, const OptionOverride& value) { if (auto* xla_field = xla::DebugOptions::descriptor()->FindFieldByName(key)) { diff --git a/xla/pjrt/pjrt_executable.h b/xla/pjrt/pjrt_executable.h index ab88f87342ce8..28e40eb515b43 100644 --- a/xla/pjrt/pjrt_executable.h +++ b/xla/pjrt/pjrt_executable.h @@ -43,6 +43,7 @@ limitations under the License. #include "xla/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/protobuf.h" namespace xla { @@ -112,6 +113,16 @@ struct CompileOptions { Status ApplyOptionFromString(const tsl::protobuf::FieldDescriptor* field, const std::string& value); + static StatusOr< + std::vector>> + LoadEnvOptionOverrides( + const google::protobuf::Map& + env_option_overrides); + + void SerializeEnvOptionOverrides( + google::protobuf::Map* + output_env_option_overrides) const; + // Serialize the CompileOptions into a CompileOptionsProto. StatusOr ToProto() const;