Skip to content

Commit

Permalink
Update MLIR passes
Browse files Browse the repository at this point in the history
  • Loading branch information
Tombana committed Jun 17, 2024
1 parent 679e355 commit 2e95843
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 51 deletions.
3 changes: 3 additions & 0 deletions larq_compute_engine/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,11 @@ cc_library(
"@local_tsl//tsl/platform:statusor",
"@org_tensorflow//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_export",
"@org_tensorflow//tensorflow/compiler/mlir/lite/debug",
"@org_tensorflow//tensorflow/compiler/mlir/lite/metrics:error_collector",
"@org_tensorflow//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"@org_tensorflow//tensorflow/compiler/mlir/lite/stablehlo:op_stat_pass",
"@org_tensorflow//tensorflow/compiler/mlir/lite/stablehlo:stablehlo_util",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:error_util",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables",
],
Expand Down
6 changes: 4 additions & 2 deletions larq_compute_engine/mlir/tf_tfl_passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void AddPreVariableFreezingTFToLCETFLConversionPasses(

// This decomposes resource ops like ResourceGather into read-variable op
// followed by gather. This is used when the saved model import path is used
// during which resources dont get frozen in the python layer.
// during which resources don't get frozen in the python layer.
pass_manager->addNestedPass<mlir::func::FuncOp>(
mlir::TFDevice::CreateDecomposeResourceOpsPass());

Expand Down Expand Up @@ -257,7 +257,9 @@ void AddPostVariableFreezingTFToLCETFLConversionPasses(

// Run quantization after all the floating point model conversion is
// completed.
if (quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
if (quant_specs.RunPropagationAndRewriteQuantizationPasses() ||
quant_specs.qdq_conversion_mode !=
mlir::quant::QDQConversionMode::kQDQNone) {
AddQuantizationPasses(quant_specs, *pass_manager);
// Remove unnecessary QDQs while handling QAT models.
pass_manager->addNestedPass<mlir::func::FuncOp>(
Expand Down
93 changes: 45 additions & 48 deletions larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
#include "larq_compute_engine/mlir/tf_tfl_passes.h"
#include "larq_compute_engine/mlir/transforms/passes.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/lite/debug/debug.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
Expand Down Expand Up @@ -55,7 +59,7 @@ class TruncateOpOrArgLocNameMapper : public OpOrArgLocNameMapper {
};

} // namespace
Status ConvertTFExecutorToTFLOrFlatbuffer(
absl::Status ConvertTFExecutorToTFLOrFlatbuffer(
mlir::ModuleOp module, bool export_to_mlir, const LCETarget target,
mlir::quant::QuantizationSpecs quant_specs,
const std::unordered_set<std::string>& saved_model_tags,
Expand All @@ -64,70 +68,59 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
// Explicitly disable dumping Op details on failures.
module.getContext()->printOpOnDiagnostic(false);

// Register a warning handler only log to std out.
mlir::ScopedDiagnosticHandler s(
module.getContext(), [](mlir::Diagnostic& diag) {
if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) {
for (auto& note : diag.getNotes()) {
std::cout << note.str() << "\n";
LOG(WARNING) << note.str() << "\n";
}
}
return mlir::failure();
});
mlir::DialectRegistry registry;
mlir::func::registerAllExtensions(registry);
module.getContext()->appendDialectRegistry(registry);

mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
/*propagate=*/true);
if (failed(IsValidGraph(module))) {
return statusHandler.ConsumeStatus();
}

mlir::PassManager pass_manager(module.getContext());
mlir::registerPassManagerCLOptions();
if (mlir::failed(mlir::applyPassManagerCLOptions(pass_manager))) {
// We don't return here as in the normal TF converter, since apparently this
// actually fails in our case, but the failure isn't terminal.
// return tensorflow::FromAbslStatus(
// absl::UnknownError("failed to apply MLIR pass manager CL options"));
return absl::InternalError("Failed to apply MLIR pass manager CL options.");
}
// DebugOptions::ir_dump_dir can be set for debugging
converter::DebugOptions debug_options;
InitPassManager(pass_manager, debug_options);

pass_manager.addInstrumentation(
std::make_unique<mlir::TFL::ErrorCollectorInstrumentation>(
pass_manager.getContext()));

if (mlir::failed(IsValidGraph(module))) {
return statusHandler.ConsumeStatus();
}

tensorflow::AddPreVariableFreezingTFToLCETFLConversionPasses(&pass_manager);
if (failed(pass_manager.run(module))) {
if (mlir::failed(pass_manager.run(module))) {
return statusHandler.ConsumeStatus();
}

// Freeze variables if a session is provided.
if (session.has_value()) {
mlir::TFL::ErrorCollectorInstrumentation collector(module.getContext());
if (mlir::failed(
mlir::tf_saved_model::FreezeVariables(module, session.value()))) {
auto status = statusHandler.ConsumeStatus();
mlir::TFL::ErrorCollector* collector =
mlir::TFL::ErrorCollector::GetErrorCollector();
if (!collector->CollectedErrors().empty()) {
return errors::InvalidArgument("Variable constant folding has failed.");
}
return status;
}
if (session.has_value() && mlir::failed(mlir::tf_saved_model::FreezeVariables(
module, session.value_or(nullptr)))) {
return statusHandler.Combine(
absl::InvalidArgumentError("Variable constant folding is failed."));
}

pass_manager.clear();

tensorflow::AddPostVariableFreezingTFToLCETFLConversionPasses(
saved_model_dir, quant_specs, &pass_manager, target);
if (failed(pass_manager.run(module))) {
auto status = statusHandler.ConsumeStatus();
mlir::TFL::ErrorCollector* collector =
mlir::TFL::ErrorCollector::GetErrorCollector();
for (const auto& error_data : collector->CollectedErrors()) {
if (error_data.subcomponent() == "FreezeGlobalTensorsPass") {
return errors::InvalidArgument("Variable constant folding is failed.");
}
}
return status;
if (mlir::failed(pass_manager.run(module))) {
return statusHandler.Combine(
absl::InvalidArgumentError("Variable constant folding failed."));
}

if (export_to_mlir) {
pass_manager.clear();
// Print out a detailed report of ops that are not converted to TFL ops.
pass_manager.addPass(mlir::odml::createPrintOpStatsPass(
mlir::odml::GetAcceptedTFLiteDialects()));
if (mlir::failed(pass_manager.run(module))) {
return statusHandler.ConsumeStatus();
}

llvm::raw_string_ostream os(*result);
module.print(os);
return statusHandler.ConsumeStatus();
Expand All @@ -142,14 +135,18 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
options.toco_flags = toco_flags;
options.saved_model_tags = saved_model_tags;
options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
return statusHandler.ConsumeStatus();
const bool serialize_stablehlo_ops = false;
if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result,
serialize_stablehlo_ops)) {
return statusHandler.Combine(
absl::InternalError("Could not translate MLIR to FlatBuffer."));
}

if (mlir::failed(module.verify())) {
return tensorflow::errors::Unknown("Final module is invalid");
if (mlir::failed(module.verifyInvariants())) {
return statusHandler.Combine(
absl::InternalError("Final module is invalid."));
}
return OkStatus();
return absl::OkStatus();
}

} // namespace tensorflow
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace tensorflow {
// This is a fork of ConvertTFExecutorToTFLOrFlatbuffer to enable custom
// OpOrArgLocNameMapper
// https://github.com/tensorflow/tensorflow/blob/v2.8.0/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h#L60-L78
Status ConvertTFExecutorToTFLOrFlatbuffer(
absl::Status ConvertTFExecutorToTFLOrFlatbuffer(
mlir::ModuleOp module, bool export_to_mlir, const LCETarget target,
mlir::quant::QuantizationSpecs quant_specs,
const std::unordered_set<std::string>& saved_model_tags,
Expand Down

0 comments on commit 2e95843

Please sign in to comment.