diff --git a/WORKSPACE b/WORKSPACE index bb36d88e..e989677d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -124,9 +124,8 @@ rules_pkg_dependencies() # tip: use `bazel run @hedron_compile_commands//:refresh_all` http_archive( name = "hedron_compile_commands", - sha256 = "f01636585c3fb61c7c2dc74df511217cd5ad16427528ab33bc76bb34535f10a1", - strip_prefix = "bazel-compile-commands-extractor-a14ad3a64e7bf398ab48105aaa0348e032ac87f8", - url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/a14ad3a64e7bf398ab48105aaa0348e032ac87f8.tar.gz", + strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e", + url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz", ) load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup") diff --git a/zirgen/circuit/predicates/gen_predicates.cpp b/zirgen/circuit/predicates/gen_predicates.cpp index c8f3fcd3..5ec7d380 100644 --- a/zirgen/circuit/predicates/gen_predicates.cpp +++ b/zirgen/circuit/predicates/gen_predicates.cpp @@ -210,6 +210,15 @@ int main(int argc, char* argv[]) { addUnion( module, "union", [&](Assumption left, Assumption right) { return unionFunc(left, right); }); - module.optimize(); - module.getModule().walk([&](mlir::func::FuncOp func) { zirgen::emitRecursion(outputDir, func); }); + mlir::PassManager pm(module.getModule()->getContext()); + if (failed(applyPassManagerCLOptions(pm))) { + exit(1); + } + module.addOptimizationPasses(pm); + pm.nest().addPass(createEmitRecursionPass(outputDir)); + + if (failed(pm.run(module.getModule()))) { + llvm::errs() << "Unable to run recursion pipeline\n"; + exit(1); + } } diff --git a/zirgen/compiler/codegen/BUILD.bazel b/zirgen/compiler/codegen/BUILD.bazel index 64e93463..8e3d1097 100644 --- a/zirgen/compiler/codegen/BUILD.bazel +++ b/zirgen/compiler/codegen/BUILD.bazel @@ -1,7 +1,34 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + package( default_visibility = ["//visibility:public"], ) +td_library( + name = "PassesTdFiles", + srcs = ["Passes.td"], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:RewritePassBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "PassesIncGen", + tbl_outs = [ + ( + ["-gen-pass-decls"], + "Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = ":Passes.td", + deps = [ + ":PassesTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + cc_library( name = "codegen", srcs = [ @@ -14,9 +41,13 @@ cc_library( "gen_rust.cpp", "mustache.h", ], - hdrs = ["codegen.h"], + hdrs = [ + "Passes.h", + "codegen.h", + ], data = [":data"], deps = [ + ":PassesIncGen", ":protocol_info_const", "//zirgen/Dialect/ZHLT/IR:Codegen", "//zirgen/Dialect/Zll/Analysis", diff --git a/zirgen/compiler/codegen/Passes.h b/zirgen/compiler/codegen/Passes.h new file mode 100644 index 00000000..9c60382e --- /dev/null +++ b/zirgen/compiler/codegen/Passes.h @@ -0,0 +1,40 @@ +// Copyright 2024 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/LocationSnapshot.h" +#include "mlir/Transforms/ViewOpGraph.h" +#include "llvm/Support/Debug.h" +#include +#include + +namespace zirgen { + +#define GEN_PASS_DECL_EMITRECURSION +#include "zirgen/compiler/codegen/Passes.h.inc" + +/// Creates a pass which outputs recursion predicates to .zkr files +std::unique_ptr> createEmitRecursionPass(); +std::unique_ptr> +createEmitRecursionPass(llvm::StringRef outputDir); + +#define GEN_PASS_REGISTRATION +#include "zirgen/compiler/codegen/Passes.h.inc" + +} // namespace zirgen diff --git a/zirgen/compiler/codegen/Passes.td b/zirgen/compiler/codegen/Passes.td new file mode 100644 index 00000000..bcbedc2c --- /dev/null +++ b/zirgen/compiler/codegen/Passes.td @@ -0,0 +1,31 @@ +// Copyright 2024 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ZIRGEN_CODEGEN_PASSSES +#define ZIRGEN_CODEGEN_PASSES + +include "mlir/Pass/PassBase.td" +include "mlir/Rewrite/PassUtil.td" + +def EmitRecursion : Pass<"emit-recursion", "mlir::func::FuncOp"> { + let summary = "Encode a function to a recursion .zkr file, and output it"; + let options = [ + Option<"outputDir", "outputdir", "std::string", /*default=*/"", + "Output directory to write .zkr to">, + ]; + let constructor = "zirgen::createEmitRecursionPass()"; +} + +#endif // ZIRGEN_CODEGEN_PASSES + diff --git a/zirgen/compiler/codegen/codegen.h b/zirgen/compiler/codegen/codegen.h index 9437c7d6..8f735eb6 100644 --- a/zirgen/compiler/codegen/codegen.h +++ b/zirgen/compiler/codegen/codegen.h @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/PassManager.h" #include "zirgen/Dialect/Zll/IR/Codegen.h" +#include "zirgen/compiler/codegen/Passes.h" #include "zirgen/compiler/codegen/protocol_info_const.h" #include "llvm/Support/ManagedStatic.h" diff --git a/zirgen/compiler/codegen/gen_recursion.cpp b/zirgen/compiler/codegen/gen_recursion.cpp index fbd62c8e..663c6c35 100644 --- a/zirgen/compiler/codegen/gen_recursion.cpp +++ b/zirgen/compiler/codegen/gen_recursion.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "zirgen/compiler/codegen/Passes.h" #include "zirgen/compiler/codegen/codegen.h" #include @@ -22,8 +23,22 @@ using namespace mlir; namespace zirgen { + +#define GEN_PASS_DEF_EMITRECURSION +#include "zirgen/compiler/codegen/Passes.h.inc" + namespace { +class EmitRecursionPass : public impl::EmitRecursionBase { +public: + EmitRecursionPass() = default; + EmitRecursionPass(StringRef dir) { this->outputDir = dir.str(); } + void runOnOperation() override { + recursion::EncodeStats stats; + emitRecursion(outputDir, getOperation(), &stats); + } +}; + std::unique_ptr openOutputFile(const std::string& path, const std::string& name) { std::string filename = path + "/" + name; @@ -56,4 +71,13 @@ void emitRecursion(const std::string& path, func::FuncOp func, recursion::Encode } } +std::unique_ptr> +createEmitRecursionPass(llvm::StringRef dir) { + return std::make_unique(dir); +} + +std::unique_ptr> createEmitRecursionPass() { + return std::make_unique(); +} + } // namespace zirgen diff --git a/zirgen/compiler/edsl/edsl.cpp b/zirgen/compiler/edsl/edsl.cpp index 221e028f..373e7267 100644 --- a/zirgen/compiler/edsl/edsl.cpp +++ b/zirgen/compiler/edsl/edsl.cpp @@ -165,8 +165,7 @@ mlir::Location CaptureVal::getLoc() { return toLoc(loc); } -// TODO: Figure out why we get more crashes in CSE when threading is enabled. -Module::Module() : ctx(MLIRContext::Threading::DISABLED), builder(&ctx) { +Module::Module() : builder(&ctx) { ctx.getOrLoadDialect(); ctx.getOrLoadDialect(); ctx.getOrLoadDialect(); @@ -174,12 +173,17 @@ Module::Module() : ctx(MLIRContext::Threading::DISABLED), builder(&ctx) { builder.setInsertionPointToEnd(&module->getBodyRegion().front()); } -void Module::optimize(size_t stageCount) { +void Module::addOptimizationPasses(PassManager& pm) { sortForReproducibility(); - PassManager pm(module->getContext()); + OpPassManager& opm = pm.nest(); opm.addPass(createCanonicalizerPass()); opm.addPass(createCSEPass()); +} + +void Module::optimize(size_t stageCount) { + PassManager pm(module->getContext()); + addOptimizationPasses(pm); if (failed(pm.run(*module))) { throw std::runtime_error("Failed to apply basic optimization passes"); } diff --git a/zirgen/compiler/edsl/edsl.h b/zirgen/compiler/edsl/edsl.h index 6737f6eb..801acc16 100644 --- a/zirgen/compiler/edsl/edsl.h +++ b/zirgen/compiler/edsl/edsl.h @@ -206,6 +206,7 @@ class Module { // to undo the varation in argument order. void sortForReproducibility(); + void addOptimizationPasses(mlir::PassManager& pm); void optimize(size_t stageCount = 0); void setExternHandler(Zll::ExternHandler* handler); void runFunc(llvm::StringRef name,