Skip to content

Commit

Permalink
Remove fallbackEmitLiteral in favor of literal handler syntax
Browse files Browse the repository at this point in the history
* Migrate more generators to use the literal handlers
* Remove no longer used literal-creation macros
  • Loading branch information
shkoo committed Sep 25, 2024
1 parent e008df2 commit b0d761e
Show file tree
Hide file tree
Showing 21 changed files with 195 additions and 264 deletions.
1 change: 0 additions & 1 deletion zirgen/Dialect/Zll/IR/Codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ struct LanguageSyntax {
virtual void emitClone(CodegenEmitter& cg, CodegenIdent<IdentKind::Var> value);
virtual void emitTakeReference(CodegenEmitter& cg, EmitPart emitTarget);

virtual void fallbackEmitLiteral(CodegenEmitter& cg, mlir::Type ty, mlir::Attribute value) = 0;
virtual void emitFuncDefinition(CodegenEmitter& cg,
CodegenIdent<IdentKind::Func> funcName,
llvm::ArrayRef<CodegenIdent<IdentKind::Var>> argNames,
Expand Down
4 changes: 3 additions & 1 deletion zirgen/Dialect/Zll/IR/CodegenEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,9 @@ void CodegenEmitter::emitLiteral(mlir::Type ty, mlir::Attribute value) {
if (succeeded(codegenType.emitLiteral(*this, value)))
return;
}
opts.lang->fallbackEmitLiteral(*this, ty, value);
llvm::errs() << "Don't know how to emit type " << ty << " with value " << value
<< " (name = " << value.getAbstractAttribute().getName() << ")\n";
abort();
}

void CodegenEmitter::emitConstDef(CodegenIdent<IdentKind::Const> name, CodegenValue value) {
Expand Down
16 changes: 0 additions & 16 deletions zirgen/Dialect/Zll/IR/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,6 @@ ValType::getTypeName(codegen::CodegenEmitter& cg) const {
}
}

mlir::LogicalResult ValType::emitLiteral(zirgen::codegen::CodegenEmitter& cg,
mlir::Attribute attr) const {
// Only emit a literal if it's a field element.
auto arrayAttr = llvm::dyn_cast<PolynomialAttr>(attr);
if (!arrayAttr)
return failure();

llvm::SmallVector<codegen::EmitPart> macroParts;
llvm::append_range(macroParts, arrayAttr.asArrayRef());
if (macroParts.size() == 1)
cg.emitInvokeMacro(cg.getStringAttr("makeVal"), macroParts);
else
cg.emitInvokeMacro(cg.getStringAttr("makeValExt"), macroParts);
return success();
}

ExtensionField ValType::getExtensionField() const {
if (getExtended())
return getField().getExtExtensionField();
Expand Down
2 changes: 1 addition & 1 deletion zirgen/Dialect/Zll/IR/Types.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ZllType<string name, string typeMnemonic, list<Trait> traits = []>
let mnemonic = typeMnemonic;
}

def Val : ZllType<"Val", "val", [DeclareTypeInterfaceMethods<CodegenTypeInterface, ["getTypeName", "emitLiteral"]>]> {
def Val : ZllType<"Val", "val", [DeclareTypeInterfaceMethods<CodegenTypeInterface, ["getTypeName"]>]> {
let summary = "An expression which results in a single field element";
let parameters = (ins
"::zirgen::Zll::FieldAttr": $field,
Expand Down
8 changes: 4 additions & 4 deletions zirgen/Dialect/Zll/IR/test/emit-codegen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ func.func @add_with_0(%arg : !zll.val<BabyBear>) -> !zll.val<BabyBear> {
%0 = zll.const 0
%1 = zll.add %0:<BabyBear>, %arg:<BabyBear>
%2 = zll.isz %1:<BabyBear>
// CPP-CHECK: Val {{.*}} = isz((MAKE_VAL(0) + arg0))
// RUST-CHECK: let x1 : Val = isz((make_val!(0) + arg0))
// CPP-CHECK: Val {{.*}} = isz((Val(0) + arg0))
// RUST-CHECK: let x1 : Val = isz((Val::new(0) + arg0))
zll.if %2 : <BabyBear> {
// CPP-CHECK: if (to_size_t(x1)) {
// RUST-CHECK: if is_nonzero(x1) {
%three = zll.const 3
zll.eqz %three : <BabyBear>
// CPP-CHECK: EQZ(MAKE_VAL(3), "Dialect/Zll/IR/test/emit-codegen.mlir:16")
// RUST-CHECK: eqz!(make_val!(3), "Dialect/Zll/IR/test/emit-codegen.mlir:16")
// CPP-CHECK: EQZ(Val(3), "Dialect/Zll/IR/test/emit-codegen.mlir:16")
// RUST-CHECK: eqz!(Val::new(3), "Dialect/Zll/IR/test/emit-codegen.mlir:16")
}
// CHECK: }
return %2: !zll.val<BabyBear>
Expand Down
46 changes: 18 additions & 28 deletions zirgen/Main/gen_zirgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ std::unique_ptr<llvm::raw_ostream> openOutput(StringRef filename) {
return ofs;
}

void emitDefs(ModuleOp mod, codegen::LanguageSyntax* lang, StringRef filename) {
codegen::CodegenOptions opts;
opts.lang = lang;

void emitDefs(ModuleOp mod, const codegen::CodegenOptions& opts, StringRef filename) {
auto os = openOutput(filename);
zirgen::codegen::CodegenEmitter emitter(opts, os.get(), mod.getContext());
auto emitZhlt = Zhlt::getEmitter(mod, emitter);
Expand All @@ -88,20 +85,14 @@ void emitDefs(ModuleOp mod, codegen::LanguageSyntax* lang, StringRef filename) {
}
}

void emitTypes(ModuleOp mod, codegen::LanguageSyntax* lang, StringRef filename) {
codegen::CodegenOptions opts;
opts.lang = lang;

void emitTypes(ModuleOp mod, const codegen::CodegenOptions& opts, StringRef filename) {
auto os = openOutput(filename);
zirgen::codegen::CodegenEmitter emitter(opts, os.get(), mod.getContext());
emitter.emitTypeDefs(mod);
}

template <typename... OpT>
void emitOps(ModuleOp mod, codegen::LanguageSyntax* lang, StringRef filename) {
codegen::CodegenOptions opts;
opts.lang = lang;

void emitOps(ModuleOp mod, const codegen::CodegenOptions& opts, StringRef filename) {
auto os = openOutput(filename);
zirgen::codegen::CodegenEmitter emitter(opts, os.get(), mod.getContext());

Expand Down Expand Up @@ -194,22 +185,21 @@ int main(int argc, char* argv[]) {
return 1;
}

static codegen::RustLanguageSyntax kRust;
static codegen::CppLanguageSyntax kCpp;

emitDefs(*typedModule, &kRust, "defs.rs.inc");
emitTypes(*typedModule, &kRust, "types.rs.inc");
emitOps<Zhlt::ValidityRegsFuncOp>(*typedModule, &kRust, "validity_regs.rs.inc");
emitOps<Zhlt::ValidityTapsFuncOp>(*typedModule, &kRust, "validity_taps.rs.inc");
emitOps<ZStruct::GlobalConstOp>(*typedModule, &kRust, "layout.rs.inc");
emitOps<Zhlt::StepFuncOp, Zhlt::ExecFuncOp>(*typedModule, &kRust, "steps.rs.inc");

emitDefs(*typedModule, &kCpp, "defs.cpp.inc");
emitTypes(*typedModule, &kCpp, "types.h.inc");
emitOps<Zhlt::ValidityTapsFuncOp>(*typedModule, &kCpp, "validity_regs.cpp.inc");
emitOps<Zhlt::ValidityRegsFuncOp>(*typedModule, &kCpp, "validity_taps.cpp.inc");
emitOps<ZStruct::GlobalConstOp>(*typedModule, &kCpp, "layout.cpp.inc");
emitOps<Zhlt::StepFuncOp, Zhlt::ExecFuncOp>(*typedModule, &kCpp, "steps.cpp.inc");
auto rustOpts = codegen::getRustCodegenOpts();
emitDefs(*typedModule, rustOpts, "defs.rs.inc");
emitTypes(*typedModule, rustOpts, "types.rs.inc");
emitOps<Zhlt::ValidityRegsFuncOp>(*typedModule, rustOpts, "validity_regs.rs.inc");
emitOps<Zhlt::ValidityTapsFuncOp>(*typedModule, rustOpts, "validity_taps.rs.inc");
emitOps<ZStruct::GlobalConstOp>(*typedModule, rustOpts, "layout.rs.inc");
emitOps<Zhlt::StepFuncOp, Zhlt::ExecFuncOp>(*typedModule, rustOpts, "steps.rs.inc");

auto cppOpts = codegen::getCppCodegenOpts();
emitDefs(*typedModule, cppOpts, "defs.cpp.inc");
emitTypes(*typedModule, cppOpts, "types.h.inc");
emitOps<Zhlt::ValidityTapsFuncOp>(*typedModule, cppOpts, "validity_regs.cpp.inc");
emitOps<Zhlt::ValidityRegsFuncOp>(*typedModule, cppOpts, "validity_taps.cpp.inc");
emitOps<ZStruct::GlobalConstOp>(*typedModule, cppOpts, "layout.cpp.inc");
emitOps<Zhlt::StepFuncOp, Zhlt::ExecFuncOp>(*typedModule, cppOpts, "steps.cpp.inc");

typedModule->print(*openOutput("circuit.ir"));

Expand Down
2 changes: 1 addition & 1 deletion zirgen/bootstrap/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ impl Args {
"src",
"",
);
copy_group(circuit, &src_path, &sys_path, ZIRGEN_SYS_OUTPUTS, "src", "");
copy_group(circuit, &src_path, &sys_path, ZIRGEN_SYS_OUTPUTS, "cxx", "");
// TODO: Improve formatting performance
// cargo_fmt_circuit(circuit, &rust_path, &None);
}
Expand Down
28 changes: 14 additions & 14 deletions zirgen/circuit/bigint/gen_bigint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,12 @@ std::unique_ptr<llvm::raw_fd_ostream> openOutputFile(StringRef path, StringRef n
return ofs;
}

void emitLang(StringRef langName,
zirgen::codegen::LanguageSyntax* lang,
StringRef path,
ModuleOp module) {
void emit(StringRef langName,
const zirgen::codegen::CodegenOptions& codegenOpts,
StringRef path,
ModuleOp module) {
auto ofs = openOutputFile(path, ("bigint." + langName + ".inc").str());

codegen::CodegenOptions codegenOpts;
codegenOpts.lang = lang;
zirgen::codegen::CodegenEmitter cg(codegenOpts, ofs.get(), module.getContext());
cg.emitModule(module);

Expand Down Expand Up @@ -201,15 +199,17 @@ int main(int argc, char* argv[]) {
throw std::runtime_error("Failed to apply basic optimization passes");
}

static codegen::RustLanguageSyntax rustLang;
rustLang.addContextArgument("ctx: &mut BigIntContext");
rustLang.addItemsMacro("bigint_program_info");
rustLang.addItemsMacro("bigint_program_list");
emitLang("rs", &rustLang, outputDir, module.getModule());
auto rustOpts = codegen::getRustCodegenOpts();
auto rustLang = dynamic_cast<codegen::RustLanguageSyntax*>(rustOpts.lang);
assert(rustLang && "expecting getRsutCodegenOpts to use RustLanguage");
rustLang->addContextArgument("ctx: &mut BigIntContext");
rustLang->addItemsMacro("bigint_program_info");
rustLang->addItemsMacro("bigint_program_list");
emit("rs", rustOpts, outputDir, module.getModule());

static codegen::CppLanguageSyntax cppLang;
cppLang.addContextArgument("BigIntContext& ctx");
emitLang("cpp", &cppLang, outputDir, module.getModule());
auto cppOpts = codegen::getRustCodegenOpts();
cppOpts.lang->addContextArgument("BigIntContext& ctx");
emit("cpp", cppOpts, outputDir, module.getModule());

PassManager pm2(module.getCtx());
if (failed(applyPassManagerCLOptions(pm2))) {
Expand Down
18 changes: 0 additions & 18 deletions zirgen/compiler/codegen/CppLanguageSyntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,6 @@ using namespace zirgen::Zll;

namespace zirgen::codegen {

void CppLanguageSyntax::fallbackEmitLiteral(CodegenEmitter& cg,
mlir::Type ty,
mlir::Attribute value) {
TypeSwitch<Attribute>(value)
.Case<IntegerAttr>([&](auto intAttr) { cg << intAttr.getValue().getZExtValue(); })
.Case<StringAttr>([&](auto strAttr) { cg.emitEscapedString(strAttr); })
.Case<PolynomialAttr>([&](auto polyAttr) {
cg << "(" << ty.cast<ValType>().getTypeName(cg) << " {";
cg.interleaveComma(polyAttr.asArrayRef());
cg << "})";
})
.Default([&](auto) {
llvm::errs() << "Don't know how to emit type " << ty << " into C++ with value " << value
<< "\n";
abort();
});
}

void CppLanguageSyntax::emitConditional(CodegenEmitter& cg,
CodegenValue condition,
EmitPart emitThen) {
Expand Down
13 changes: 0 additions & 13 deletions zirgen/compiler/codegen/RustLanguageSyntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,6 @@ void RustLanguageSyntax::emitSwitchStatement(CodegenEmitter& cg,
cg << "}";
}

void RustLanguageSyntax::fallbackEmitLiteral(CodegenEmitter& cg,
mlir::Type ty,
mlir::Attribute value) {
TypeSwitch<Attribute>(value)
.Case<IntegerAttr>([&](auto intAttr) { cg << intAttr.getValue().getZExtValue(); })
.Case<StringAttr>([&](auto strAttr) { cg.emitEscapedString(strAttr); })
.Default([&](auto) {
llvm::errs() << "Don't know how to emit type " << ty << " into rust++ with value " << value
<< "\n";
abort();
});
}

void RustLanguageSyntax::emitFuncDefinition(CodegenEmitter& cg,
CodegenIdent<IdentKind::Func> funcName,
llvm::ArrayRef<CodegenIdent<IdentKind::Var>> argNames,
Expand Down
44 changes: 43 additions & 1 deletion zirgen/compiler/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,68 @@ using namespace mlir;
namespace cl = llvm::cl;

namespace zirgen {

namespace codegen {

namespace {

void addCommonSyntax(CodegenOptions& opts) {
opts.addLiteralHandler<IntegerAttr>(
[](CodegenEmitter& cg, auto intAttr) { cg << intAttr.getValue().getZExtValue(); });
opts.addLiteralHandler<StringAttr>(
[](CodegenEmitter& cg, auto strAttr) { cg.emitEscapedString(strAttr); });
}

void addCppSyntax(CodegenOptions& opts) {
opts.addLiteralHandler<PolynomialAttr>([&](CodegenEmitter& cg, auto polyAttr) {
auto elems = polyAttr.asArrayRef();
if (elems.size() == 1) {
cg << "Val(" << elems[0] << ")";
} else {
cg << "Val" << elems.size() << "{";
cg.interleaveComma(elems);
cg << "}";
}
});
}

void addRustSyntax(CodegenOptions& opts) {
opts.addLiteralHandler<PolynomialAttr>([&](CodegenEmitter& cg, auto polyAttr) {
auto elems = polyAttr.asArrayRef();
if (elems.size() == 1) {
cg << "Val::new(" << elems[0] << ")";
} else {
cg << "ExtVal::new(";
cg.interleaveComma(elems);
cg << ")";
}
});
}

} // namespace

CodegenOptions getRustCodegenOpts() {
static codegen::RustLanguageSyntax kRust;
codegen::CodegenOptions opts(&kRust);
addCommonSyntax(opts);
addRustSyntax(opts);
ZStruct::addRustSyntax(opts);
return opts;
}

CodegenOptions getCppCodegenOpts() {
static codegen::CppLanguageSyntax kCpp;
codegen::CodegenOptions opts(&kCpp);
addCommonSyntax(opts);
addCppSyntax(opts);
ZStruct::addCppSyntax(opts);
return opts;
}

CodegenOptions getCudaCodegenOpts() {
static codegen::CudaLanguageSyntax kCuda;
codegen::CodegenOptions opts(&kCuda);
addCommonSyntax(opts);
addCppSyntax(opts);
ZStruct::addCppSyntax(opts);
return opts;
}
Expand Down
2 changes: 0 additions & 2 deletions zirgen/compiler/codegen/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class RustLanguageSyntax : public LanguageSyntax {
std::string canonIdent(llvm::StringRef ident, IdentKind idt) override;
void emitClone(CodegenEmitter& cg, CodegenIdent<IdentKind::Var> value) override;
void emitTakeReference(CodegenEmitter& cg, EmitPart emitTarget) override;
void fallbackEmitLiteral(CodegenEmitter& cg, mlir::Type ty, mlir::Attribute value) override;

void emitConditional(CodegenEmitter& cg, CodegenValue condition, EmitPart emitThen) override;
void emitSwitchStatement(CodegenEmitter& cg,
Expand Down Expand Up @@ -161,7 +160,6 @@ struct CppLanguageSyntax : public LanguageSyntax {
LanguageKind getLanguageKind() override { return LanguageKind::Cpp; }

std::string canonIdent(llvm::StringRef ident, IdentKind idt) override;
void fallbackEmitLiteral(CodegenEmitter& cg, mlir::Type ty, mlir::Attribute value) override;

void emitConditional(CodegenEmitter& cg, CodegenValue condition, EmitPart emitThen) override;
void emitSwitchStatement(CodegenEmitter& cg,
Expand Down
8 changes: 2 additions & 6 deletions zirgen/compiler/tools/zirgen-translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ int main(int argc, char** argv) {
"rust-codegen",
"",
[](mlir::ModuleOp module, llvm::raw_ostream& output) {
static codegen::RustLanguageSyntax rust;
codegen::CodegenOptions opts;
opts.lang = &rust;
codegen::CodegenOptions opts = codegen::getRustCodegenOpts();
if (!funcName.empty()) {
auto func = module.lookupSymbol<mlir::func::FuncOp>(funcName);
if (!func) {
Expand All @@ -107,9 +105,7 @@ int main(int argc, char** argv) {
"cpp-codegen",
"",
[](mlir::ModuleOp module, llvm::raw_ostream& output) {
static codegen::CppLanguageSyntax cpp;
codegen::CodegenOptions opts;
opts.lang = &cpp;
codegen::CodegenOptions opts = codegen::getCppCodegenOpts();
if (!funcName.empty()) {
auto func = module.lookupSymbol<mlir::func::FuncOp>(funcName);
if (!func) {
Expand Down
11 changes: 3 additions & 8 deletions zirgen/dsl/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,9 @@ int main(int argc, char* argv[]) {
}

if (emitAction == Action::PrintRust || emitAction == Action::PrintCpp) {
codegen::CodegenOptions codegenOpts;
static codegen::RustLanguageSyntax kRust;
static codegen::CppLanguageSyntax kCpp;

codegenOpts.lang = (emitAction == Action::PrintRust)
? static_cast<codegen::LanguageSyntax*>(&kRust)
: static_cast<codegen::LanguageSyntax*>(&kCpp);

codegen::CodegenOptions codegenOpts = (emitAction == Action::PrintRust)
? codegen::getRustCodegenOpts()
: codegen::getCppCodegenOpts();
zirgen::codegen::CodegenEmitter emitter(codegenOpts, &llvm::outs(), &context);
if (zirgen::Zhlt::emitModule(*typedModule, emitter).failed()) {
llvm::errs() << "Failed to emit circuit\n";
Expand Down
Loading

0 comments on commit b0d761e

Please sign in to comment.