Skip to content

Commit

Permalink
generate recursion predicates (.zkr) in parallel (#123)
Browse files Browse the repository at this point in the history
* Add an EmitRecursion pass to wrap emitRecursion to generate zkrs that the pass manager can run in parallel
* Enable parallelism in EDSL context
* (Incidental change: update hedron to fix clangd indexing)
  • Loading branch information
shkoo authored Dec 16, 2024
1 parent c641ae8 commit 13f7797
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 10 deletions.
5 changes: 2 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,8 @@ rules_pkg_dependencies()
# tip: use `bazel run @hedron_compile_commands//:refresh_all`
http_archive(
name = "hedron_compile_commands",
sha256 = "f01636585c3fb61c7c2dc74df511217cd5ad16427528ab33bc76bb34535f10a1",
strip_prefix = "bazel-compile-commands-extractor-a14ad3a64e7bf398ab48105aaa0348e032ac87f8",
url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/a14ad3a64e7bf398ab48105aaa0348e032ac87f8.tar.gz",
strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e",
url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz",
)

load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup")
Expand Down
13 changes: 11 additions & 2 deletions zirgen/circuit/predicates/gen_predicates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@ int main(int argc, char* argv[]) {
addUnion(
module, "union", [&](Assumption left, Assumption right) { return unionFunc(left, right); });

module.optimize();
module.getModule().walk([&](mlir::func::FuncOp func) { zirgen::emitRecursion(outputDir, func); });
mlir::PassManager pm(module.getModule()->getContext());
if (failed(applyPassManagerCLOptions(pm))) {
exit(1);
}
module.addOptimizationPasses(pm);
pm.nest<mlir::func::FuncOp>().addPass(createEmitRecursionPass(outputDir));

if (failed(pm.run(module.getModule()))) {
llvm::errs() << "Unable to run recursion pipeline\n";
exit(1);
}
}
33 changes: 32 additions & 1 deletion zirgen/compiler/codegen/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,7 +1,34 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")

package(
default_visibility = ["//visibility:public"],
)

td_library(
name = "PassesTdFiles",
srcs = ["Passes.td"],
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:RewritePassBaseTdFiles",
],
)

gentbl_cc_library(
name = "PassesIncGen",
tbl_outs = [
(
["-gen-pass-decls"],
"Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = ":Passes.td",
deps = [
":PassesTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)

cc_library(
name = "codegen",
srcs = [
Expand All @@ -14,9 +41,13 @@ cc_library(
"gen_rust.cpp",
"mustache.h",
],
hdrs = ["codegen.h"],
hdrs = [
"Passes.h",
"codegen.h",
],
data = [":data"],
deps = [
":PassesIncGen",
":protocol_info_const",
"//zirgen/Dialect/ZHLT/IR:Codegen",
"//zirgen/Dialect/Zll/Analysis",
Expand Down
40 changes: 40 additions & 0 deletions zirgen/compiler/codegen/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Transforms/ViewOpGraph.h"
#include "llvm/Support/Debug.h"
#include <limits>
#include <memory>

namespace zirgen {

#define GEN_PASS_DECL_EMITRECURSION
#include "zirgen/compiler/codegen/Passes.h.inc"

/// Creates a pass which outputs recursion predicates to .zkr files
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> createEmitRecursionPass();
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createEmitRecursionPass(llvm::StringRef outputDir);

#define GEN_PASS_REGISTRATION
#include "zirgen/compiler/codegen/Passes.h.inc"

} // namespace zirgen
31 changes: 31 additions & 0 deletions zirgen/compiler/codegen/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef ZIRGEN_CODEGEN_PASSSES
#define ZIRGEN_CODEGEN_PASSES

include "mlir/Pass/PassBase.td"
include "mlir/Rewrite/PassUtil.td"

def EmitRecursion : Pass<"emit-recursion", "mlir::func::FuncOp"> {
let summary = "Encode a function to a recursion .zkr file, and output it";
let options = [
Option<"outputDir", "outputdir", "std::string", /*default=*/"",
"Output directory to write .zkr to">,
];
let constructor = "zirgen::createEmitRecursionPass()";
}

#endif // ZIRGEN_CODEGEN_PASSES

1 change: 1 addition & 0 deletions zirgen/compiler/codegen/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "zirgen/Dialect/Zll/IR/Codegen.h"
#include "zirgen/compiler/codegen/Passes.h"
#include "zirgen/compiler/codegen/protocol_info_const.h"
#include "llvm/Support/ManagedStatic.h"

Expand Down
24 changes: 24 additions & 0 deletions zirgen/compiler/codegen/gen_recursion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "zirgen/compiler/codegen/Passes.h"
#include "zirgen/compiler/codegen/codegen.h"

#include <fstream>
Expand All @@ -22,8 +23,22 @@
using namespace mlir;

namespace zirgen {

#define GEN_PASS_DEF_EMITRECURSION
#include "zirgen/compiler/codegen/Passes.h.inc"

namespace {

class EmitRecursionPass : public impl::EmitRecursionBase<EmitRecursionPass> {
public:
EmitRecursionPass() = default;
EmitRecursionPass(StringRef dir) { this->outputDir = dir.str(); }
void runOnOperation() override {
recursion::EncodeStats stats;
emitRecursion(outputDir, getOperation(), &stats);
}
};

std::unique_ptr<llvm::raw_fd_ostream> openOutputFile(const std::string& path,
const std::string& name) {
std::string filename = path + "/" + name;
Expand Down Expand Up @@ -56,4 +71,13 @@ void emitRecursion(const std::string& path, func::FuncOp func, recursion::Encode
}
}

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
createEmitRecursionPass(llvm::StringRef dir) {
return std::make_unique<EmitRecursionPass>(dir);
}

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> createEmitRecursionPass() {
return std::make_unique<EmitRecursionPass>();
}

} // namespace zirgen
12 changes: 8 additions & 4 deletions zirgen/compiler/edsl/edsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,21 +165,25 @@ mlir::Location CaptureVal::getLoc() {
return toLoc(loc);
}

// TODO: Figure out why we get more crashes in CSE when threading is enabled.
Module::Module() : ctx(MLIRContext::Threading::DISABLED), builder(&ctx) {
Module::Module() : builder(&ctx) {
ctx.getOrLoadDialect<ZllDialect>();
ctx.getOrLoadDialect<Iop::IopDialect>();
ctx.getOrLoadDialect<ZStruct::ZStructDialect>();
module = ModuleOp::create(UnknownLoc::get(&ctx));
builder.setInsertionPointToEnd(&module->getBodyRegion().front());
}

void Module::optimize(size_t stageCount) {
void Module::addOptimizationPasses(PassManager& pm) {
sortForReproducibility();
PassManager pm(module->getContext());

OpPassManager& opm = pm.nest<func::FuncOp>();
opm.addPass(createCanonicalizerPass());
opm.addPass(createCSEPass());
}

void Module::optimize(size_t stageCount) {
PassManager pm(module->getContext());
addOptimizationPasses(pm);
if (failed(pm.run(*module))) {
throw std::runtime_error("Failed to apply basic optimization passes");
}
Expand Down
1 change: 1 addition & 0 deletions zirgen/compiler/edsl/edsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class Module {
// to undo the varation in argument order.
void sortForReproducibility();

void addOptimizationPasses(mlir::PassManager& pm);
void optimize(size_t stageCount = 0);
void setExternHandler(Zll::ExternHandler* handler);
void runFunc(llvm::StringRef name,
Expand Down

0 comments on commit 13f7797

Please sign in to comment.