diff --git a/scripts/execute.py b/scripts/execute.py index 4c51350..4448bbc 100644 --- a/scripts/execute.py +++ b/scripts/execute.py @@ -327,18 +327,18 @@ def dfs(curr_dir: str): log_path = os.path.join(output_dir, f"{basename}.log") log_file = open(log_path, "w") - command = ( - f"clang --target=riscv64 -march=rv64imafdc_zba_zbb -w -xc -O3 -S {testcase}.sy" - f" -o {clang_asm_path}" - ) - - exec_result = execute(command, exec_timeout) - log(log_file, command, exec_result) - - if exec_result["returncode"] is None or exec_result["stderr"] != "": - result_md_table += f"| `{testcase}` | 😢 CE |\n" - print(f"\033[33m[ ERROR ] (clang CE)\033[0m {testcase}, see: ", log_path) - continue + # command = ( + # f"clang --target=riscv64 -march=rv64imafdc_zba_zbb -w -xc -Wno-implicit-function-declaration -O3 -S {testcase}.sy" + # f" -o {clang_asm_path}" + # ) + + # exec_result = execute(command, exec_timeout) + # log(log_file, command, exec_result) + + # if exec_result["returncode"] is None or exec_result["stderr"] != "": + # result_md_table += f"| `{testcase}` | 😢 CE |\n" + # print(f"\033[33m[ ERROR ] (clang CE)\033[0m {testcase}, see: ", log_path) + # continue command = ( f"{executable_path} -S " diff --git a/src/backend/riscv64/peephole.rs b/src/backend/riscv64/peephole.rs index 7d2c555..f509f63 100644 --- a/src/backend/riscv64/peephole.rs +++ b/src/backend/riscv64/peephole.rs @@ -1035,7 +1035,7 @@ pub fn remove_redundant_labels(mctx: &mut MContext) -> bool { changed } -pub fn run_peephole(mctx: &mut MContext, config: &LowerConfig) -> bool { +pub fn run_peephole(mctx: &mut MContext, config: &LowerConfig, aggressive: bool) -> bool { let mut runner1 = PeepholeRunner::new(); let mut runner2 = PeepholeRunner::new(); let mut runner3 = PeepholeRunner::new(); @@ -1043,10 +1043,14 @@ pub fn run_peephole(mctx: &mut MContext, config: &LowerConfig) -> bool { runner1.add_rule(li_dce()); runner1.add_rule(remove_identity_move()); runner1.add_rule(remove_redundant_move()); - runner1.add_rule(remove_redundant_move_word()); // aggressive + if aggressive { + runner1.add_rule(remove_redundant_move_word()); + } runner2.add_rule(fuse_cmp_br()); - runner2.add_rule(fuse_fmul_faddfsub()); // aggressive + if aggressive { + runner2.add_rule(fuse_fmul_faddfsub()); + } runner2.add_rule(fuse_sub_br()); runner3.add_rule(fuse_xori_cmp_br()); diff --git a/src/backend/riscv64/schedule.rs b/src/backend/riscv64/schedule.rs index 0945154..eee7223 100644 --- a/src/backend/riscv64/schedule.rs +++ b/src/backend/riscv64/schedule.rs @@ -6,13 +6,7 @@ use std::{ use super::inst::{AluOpRRR, FpuOpRR, FpuOpRRR, FpuOpRRRR, RvInst, RvInstKind}; use crate::{ - backend::{ - inst::{DisplayMInst, MInst}, - LowerConfig, - MBlock, - MContext, - MFunc, - }, + backend::{inst::MInst, LowerConfig, MBlock, MContext, MFunc}, collections::linked_list::{LinkedListContainerPtr, LinkedListNodePtr}, }; @@ -882,7 +876,7 @@ fn shedule_chunk(mctx: &mut MContext, start: RvInst, end: RvInst, config s += 1; } - println!("scheduling: {}, s: {}", n.display(mctx), s); + // println!("scheduling: {}, s: {}", n.display(mctx), s); scheduled.insert(n, s); diff --git a/src/bin/compiler.rs b/src/bin/compiler.rs index b469f2f..3d2ba8e 100644 --- a/src/bin/compiler.rs +++ b/src/bin/compiler.rs @@ -28,18 +28,31 @@ use orzcc::{ global_dce::{GlobalDce, GLOBAL_DCE}, gvn::{GlobalValueNumbering, GVN}, inline::{Inline, INLINE}, - instcombine::{InstCombine, INSTCOMBINE}, + instcombine::{ + AdvancedInstcombine, + AggressiveInstcombine, + Instcombine, + ADVANCED_INSTCOMBINE, + AGGRESSIVE_INSTCOMBINE, + INSTCOMBINE, + }, legalize::{Legalize, LEGALIZE}, loops::{ DeadLoopElim, + IndvarOffset, + IndvarReduce, IndvarSimplify, Lcssa, LoopPeel, LoopSimplify, + LoopStrengthReduction, LoopUnroll, DEAD_LOOP_ELIM, + INDVAR_OFFSET, + INDVAR_REDUCE, INDVAR_SIMPLIFY, LOOP_PEEL, + LOOP_STRENGTH_REDUCTION, LOOP_UNROLL, }, mem2reg::{Mem2reg, MEM2REG}, @@ -65,7 +78,8 @@ struct CliCommand { emit_vcode: Option, /// Optimization level opt: u8, - + /// If aggressive optimizations are enabled + aggressive: bool, /// Lower config lower_cfg: LowerConfig, } @@ -100,117 +114,207 @@ fn main() -> Result<(), Box> { if cmd.opt > 0 { let mut pipe_basic = Pipeline::default(); + { + pipe_basic.add_pass(GLOBAL2LOCAL); + pipe_basic.add_pass(GLOBAL_DCE); + + pipe_basic.add_pass(MEM2REG); + pipe_basic.add_pass(ELIM_CONSTANT_PHI); + pipe_basic.add_pass(SIMPLE_DCE); + pipe_basic.add_pass(CFG_SIMPLIFY); + + pipe_basic.add_pass(CONSTANT_FOLDING); + pipe_basic.add_pass(ELIM_CONSTANT_PHI); + pipe_basic.add_pass(SIMPLE_DCE); + pipe_basic.add_pass(CFG_SIMPLIFY); + + pipe_basic.add_pass(INSTCOMBINE); + pipe_basic.add_pass(ELIM_CONSTANT_PHI); + pipe_basic.add_pass(SIMPLE_DCE); + pipe_basic.add_pass(CFG_SIMPLIFY); + + pipe_basic.add_pass(GCM); + pipe_basic.add_pass(ELIM_CONSTANT_PHI); + pipe_basic.add_pass(SIMPLE_DCE); + pipe_basic.add_pass(CFG_SIMPLIFY); + pipe_basic.add_pass(BRANCH_CONDITION_SINK); + + pipe_basic.add_pass(BRANCH2SELECT); + pipe_basic.add_pass(ELIM_CONSTANT_PHI); + pipe_basic.add_pass(SIMPLE_DCE); + pipe_basic.add_pass(CFG_SIMPLIFY); + } - pipe_basic.add_pass(GLOBAL2LOCAL); - pipe_basic.add_pass(SIMPLE_DCE); - pipe_basic.add_pass(GLOBAL_DCE); - pipe_basic.add_pass(MEM2REG); - pipe_basic.add_pass(SIMPLE_DCE); - pipe_basic.add_pass(CFG_SIMPLIFY); - pipe_basic.add_pass(CONSTANT_FOLDING); - pipe_basic.add_pass(SIMPLE_DCE); - pipe_basic.add_pass(INSTCOMBINE); - pipe_basic.add_pass(SIMPLE_DCE); - pipe_basic.add_pass(GCM); - pipe_basic.add_pass(BRANCH_CONDITION_SINK); - pipe_basic.add_pass(GVN); - pipe_basic.add_pass(CFG_SIMPLIFY); - pipe_basic.add_pass(ELIM_CONSTANT_PHI); - pipe_basic.add_pass(SIMPLE_DCE); - pipe_basic.add_pass(CFG_SIMPLIFY); - pipe_basic.add_pass(BRANCH2SELECT); - pipe_basic.add_pass(CFG_SIMPLIFY); + // basic + gvn, run after legalization. + let mut pipe_gvn = Pipeline::default(); + { + pipe_gvn.add_pass(GLOBAL2LOCAL); + pipe_gvn.add_pass(GLOBAL_DCE); + + pipe_gvn.add_pass(MEM2REG); + pipe_gvn.add_pass(ELIM_CONSTANT_PHI); + pipe_gvn.add_pass(SIMPLE_DCE); + pipe_gvn.add_pass(CFG_SIMPLIFY); + + pipe_gvn.add_pass(CONSTANT_FOLDING); + pipe_gvn.add_pass(ELIM_CONSTANT_PHI); + pipe_gvn.add_pass(SIMPLE_DCE); + pipe_gvn.add_pass(CFG_SIMPLIFY); + + pipe_gvn.add_pass(INSTCOMBINE); + pipe_gvn.add_pass(ELIM_CONSTANT_PHI); + pipe_gvn.add_pass(SIMPLE_DCE); + pipe_gvn.add_pass(CFG_SIMPLIFY); + + pipe_gvn.add_pass(GCM); + pipe_gvn.add_pass(ELIM_CONSTANT_PHI); + pipe_gvn.add_pass(SIMPLE_DCE); + pipe_gvn.add_pass(CFG_SIMPLIFY); + pipe_gvn.add_pass(BRANCH_CONDITION_SINK); + + pipe_gvn.add_pass(BRANCH2SELECT); + pipe_gvn.add_pass(ELIM_CONSTANT_PHI); + pipe_gvn.add_pass(SIMPLE_DCE); + pipe_gvn.add_pass(CFG_SIMPLIFY); + + pipe_gvn.add_pass(GVN); + pipe_gvn.add_pass(ELIM_CONSTANT_PHI); + pipe_gvn.add_pass(SIMPLE_DCE); + pipe_gvn.add_pass(CFG_SIMPLIFY); + } - let mut pipe_inline = Pipeline::default(); + let mut pipe_tco = Pipeline::default(); + { + pipe_tco.add_pass(TCO); + pipe_tco.add_pass(ELIM_CONSTANT_PHI); + pipe_tco.add_pass(SIMPLE_DCE); + pipe_tco.add_pass(CFG_SIMPLIFY); + } - pipe_inline.add_pass(INLINE); - pipe_inline.add_pass(CFG_SIMPLIFY); - pipe_inline.add_pass(SIMPLE_DCE); - pipe_inline.add_pass(GLOBAL_DCE); + let mut pipe_inline = Pipeline::default(); + { + pipe_inline.add_pass(INLINE); + pipe_inline.add_pass(ELIM_CONSTANT_PHI); + pipe_inline.add_pass(SIMPLE_DCE); + pipe_inline.add_pass(CFG_SIMPLIFY); + pipe_inline.add_pass(GLOBAL_DCE); + } let mut pipe_unroll = Pipeline::default(); + { + pipe_unroll.add_pass(LOOP_UNROLL); + pipe_unroll.add_pass(CONSTANT_FOLDING); + pipe_unroll.add_pass(ELIM_CONSTANT_PHI); + pipe_unroll.add_pass(SIMPLE_DCE); + pipe_unroll.add_pass(CFG_SIMPLIFY); + } - pipe_unroll.add_pass(LOOP_UNROLL); - pipe_unroll.add_pass(CONSTANT_FOLDING); - pipe_unroll.add_pass(CFG_SIMPLIFY); - pipe_unroll.add_pass(SIMPLE_DCE); - - passman.run_transform(GLOBAL2LOCAL, &mut ir, 32); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); - passman.run_transform(GLOBAL_DCE, &mut ir, 32); - passman.run_transform(MEM2REG, &mut ir, 32); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); - passman.run_transform(CFG_SIMPLIFY, &mut ir, 32); - passman.run_transform(CONSTANT_FOLDING, &mut ir, 32); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); - passman.run_transform(INSTCOMBINE, &mut ir, 32); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); - - passman.run_transform(TCO, &mut ir, 1); - passman.run_transform(CFG_SIMPLIFY, &mut ir, 32); - - passman.run_transform(LOOP_PEEL, &mut ir, 1); - passman.run_transform(ELIM_CONSTANT_PHI, &mut ir, 32); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); - passman.run_transform(GCM, &mut ir, 32); - passman.run_transform(BRANCH_CONDITION_SINK, &mut ir, 1); - passman.run_transform(INDVAR_SIMPLIFY, &mut ir, 1); - passman.run_transform(CONSTANT_FOLDING, &mut ir, 32); - passman.run_transform(CFG_SIMPLIFY, &mut ir, 32); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); - passman.run_transform(ELIM_CONSTANT_PHI, &mut ir, 32); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); - passman.run_transform(DEAD_LOOP_ELIM, &mut ir, 1); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); - passman.run_transform(CFG_SIMPLIFY, &mut ir, 32); - passman.run_transform(CONSTANT_FOLDING, &mut ir, 32); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); - passman.run_transform(ELIM_CONSTANT_PHI, &mut ir, 32); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); - - passman.run_transform(INLINE, &mut ir, 1); - passman.run_transform(CONSTANT_FOLDING, &mut ir, 32); - passman.run_transform(CFG_SIMPLIFY, &mut ir, 32); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); - passman.run_transform(GLOBAL_DCE, &mut ir, 32); - - // TODO: unroll earlier to combine load/store - passman.run_transform(LOOP_UNROLL, &mut ir, 2); - passman.run_transform(GCM, &mut ir, 32); - passman.run_transform(BRANCH_CONDITION_SINK, &mut ir, 1); - passman.run_transform(CONSTANT_FOLDING, &mut ir, 32); - passman.run_transform(CFG_SIMPLIFY, &mut ir, 32); - passman.run_transform(SIMPLE_DCE, &mut ir, 32); + // initial pipelines, remove redundant code and simplify the control flow. + { + passman.run_pipeline(&mut ir, &pipe_basic, 32, 8); - passman.run_transform(LEGALIZE, &mut ir, 1); + passman.run_pipeline(&mut ir, &pipe_tco, 32, 8); + passman.run_pipeline(&mut ir, &pipe_basic, 32, 8); + + passman.run_pipeline(&mut ir, &pipe_inline, 32, 8); + passman.run_pipeline(&mut ir, &pipe_basic, 32, 8); + } + + // legalize to remove high level operations. + { + passman.run_transform(LEGALIZE, &mut ir, 1); + passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); + } + + // reduce the strength of operations with the loop, especially multiplication + // with indvars. + { + // remove redundant induction variables. + let iter = passman.run_transform(INDVAR_REDUCE, &mut ir, 32); + println!("indvar-reduce iterations: {}", iter); + passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); + + let iter = passman.run_transform(LOOP_STRENGTH_REDUCTION, &mut ir, 32); + println!("loop strength reduction iterations: {}", iter); + passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); + } + + // loop peeling to remove inefficient inner loop patterns. + { + // loop-peeling tend to eliminate the inner loops that will only be executed in + // the first trip of outer loop. + passman.run_transform(LOOP_PEEL, &mut ir, 1); + // the control indvar of the inner loop will be simplified, and passed to the + // original loop, as the new init. + passman.run_transform(INDVAR_SIMPLIFY, &mut ir, 1); + // there might be nested argument passing, we did not detect that in + // `dead-loop-elim`, but just regard it as constant phi. + passman.run_transform(ELIM_CONSTANT_PHI, &mut ir, 32); + // remove redundant inner loops. + passman.run_transform(DEAD_LOOP_ELIM, &mut ir, 1); + // remove all redundant code. + passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); + } + // iterate several times, seeking more opportunities. for i in 0..4 { println!("Round {}", i); - let iter = passman.run_pipeline(&mut ir, &pipe_basic, 32, 8); - println!("pipeline basic iterations: {}", iter); + // aggressive dce is a little expensive, run once per round + passman.run_transform(ADCE, &mut ir, 1); + + let iter = passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); + println!("pipeline gvn iterations: {}", iter); let iter = passman.run_pipeline(&mut ir, &pipe_inline, 32, 8); println!("pipeline inline iterations: {}", iter); + let iter = passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); + println!("pipeline gvn iterations: {}", iter); + let iter = passman.run_pipeline(&mut ir, &pipe_unroll, 1, 1); println!("pipeline unroll iterations: {}", iter); - // a little expensive, run once per round + let iter = passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); + println!("pipeline gvn iterations: {}", iter); + } + + // optimize address generation inside loops. + { + // induce offset instructions that uses induction variables. This is placed here + // because there can be complex alias problem after inducing the offset + // instructions. + passman.run_transform(INDVAR_OFFSET, &mut ir, 32); + passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); + + // aggressive dce will remove redundant indvars after `indvar-offset` passman.run_transform(ADCE, &mut ir, 1); - passman.run_transform(CFG_SIMPLIFY, &mut ir, 32); + passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); } + // reorder after loop unrolling. passman.run_transform(BLOCK_REORDER, &mut ir, 1); + // TODO: refactor everything below. passman.run_transform(BOOL2COND, &mut ir, 32); - passman.run_transform(CFG_SIMPLIFY, &mut ir, 32); + passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); + for i in 0..4 { println!("Second Round {}", i); - let iter = passman.run_pipeline(&mut ir, &pipe_basic, 32, 8); - println!("pipeline basic iterations: {}", iter); + + passman.run_transform(ADVANCED_INSTCOMBINE, &mut ir, 32); + + let iter = passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); + println!("pipeline gvn iterations: {}", iter); + passman.run_transform(ADCE, &mut ir, 1); - passman.run_transform(CFG_SIMPLIFY, &mut ir, 32); + + let iter = passman.run_pipeline(&mut ir, &pipe_gvn, 32, 8); + println!("pipeline gvn iterations: {}", iter); + + if cmd.aggressive { + passman.run_transform(AGGRESSIVE_INSTCOMBINE, &mut ir, 32); + } } } else { passman.run_transform(LEGALIZE, &mut ir, 1); @@ -230,7 +334,7 @@ fn main() -> Result<(), Box> { lower_ctx.lower(); if cmd.opt > 0 { - riscv64::run_peephole(lower_ctx.mctx_mut(), &cmd.lower_cfg); + riscv64::run_peephole(lower_ctx.mctx_mut(), &cmd.lower_cfg, cmd.aggressive); SimplifyCfg::run(lower_ctx.mctx_mut(), &cmd.lower_cfg); RegisterCoalescing::run::(&mut lower_ctx, &cmd.lower_cfg); schedule(lower_ctx.mctx_mut(), &cmd.lower_cfg, Some(128)); @@ -268,8 +372,12 @@ fn register_passes(passman: &mut PassManager) { Mem2reg::register(passman); SimpleDce::register(passman); Adce::register(passman); + ConstantFolding::register(passman); - InstCombine::register(passman); + Instcombine::register(passman); + AdvancedInstcombine::register(passman); + AggressiveInstcombine::register(passman); + ElimConstantPhi::register(passman); Branch2Select::register(passman); Bool2Cond::register(passman); @@ -286,6 +394,9 @@ fn register_passes(passman: &mut PassManager) { LoopPeel::register(passman); IndvarSimplify::register(passman); DeadLoopElim::register(passman); + LoopStrengthReduction::register(passman); + IndvarOffset::register(passman); + IndvarReduce::register(passman); GlobalValueNumbering::register(passman); Gcm::register(passman); @@ -346,6 +457,11 @@ fn cli(passman: &mut PassManager) -> Command { .long("no-omit-frame-pointer") .action(clap::ArgAction::Count), ) + .arg( + Arg::new("aggressive") + .long("aggressive") + .action(clap::ArgAction::Count), + ) .args(passman.get_cli_args()) } @@ -391,6 +507,8 @@ fn parse_args(passman: &mut PassManager) -> CliCommand { combine_stack_adjustments, }; + let aggressive = matches.get_count("aggressive") > 0; + CliCommand { output, source, @@ -399,6 +517,7 @@ fn parse_args(passman: &mut PassManager) -> CliCommand { emit_ir, emit_vcode, opt, + aggressive, lower_cfg, } } diff --git a/src/ir/fold.rs b/src/ir/fold.rs index 98d5cd0..0373c01 100644 --- a/src/ir/fold.rs +++ b/src/ir/fold.rs @@ -36,6 +36,8 @@ impl FoldedConstant { panic!("unwrap_float: not a float constant"); } } + + pub fn is_undef(&self) -> bool { matches!(self, FoldedConstant::Undef) } } /// The context of the constant folding. @@ -62,8 +64,14 @@ impl FoldContext { impl Inst { /// Fold the instruction with a given constant folding context. - pub fn fold(self, ctx: &Context, fold_ctx: &mut FoldContext) -> Option { + pub fn fold( + self, + ctx: &Context, + fold_ctx: &mut FoldContext, + aggressive: bool, + ) -> Option { match self.kind(ctx) { + // XXX: SysY undefined value is not defined accurately. InstKind::Undef => Some(FoldedConstant::Undef), InstKind::IConst(value) => { let width = self.result(ctx, 0).ty(ctx).bitwidth(ctx); @@ -301,8 +309,8 @@ impl Inst { mut_int_val.signext(target_width); Some(FoldedConstant::Integer(mut_int_val)) } - CastOp::UiToFp => { - // FIXME: 在38_light2d中,存在cast int 100000006 to float + CastOp::UiToFp if aggressive => { + // AGGRESSIVE: 在38_light2d中,存在cast int 100000006 to float // 100000010的情况。 let target_ty = self.result(ctx, 0).ty(ctx); let u64_val: u64 = mut_int_val.into(); @@ -315,7 +323,7 @@ impl Inst { }; Some(FoldedConstant::Float(float_val)) } - CastOp::SiToFp => { + CastOp::SiToFp if aggressive => { let target_ty = self.result(ctx, 0).ty(ctx); let u64_val: u64 = mut_int_val.into(); let i64_val: i64 = @@ -334,6 +342,8 @@ impl Inst { | CastOp::FpToSi | CastOp::Bitcast | CastOp::FpExt + | CastOp::UiToFp + | CastOp::SiToFp | CastOp::PtrToInt | CastOp::IntToPtr => None, } diff --git a/src/ir/inst.rs b/src/ir/inst.rs index ac86941..d375f38 100644 --- a/src/ir/inst.rs +++ b/src/ir/inst.rs @@ -1495,7 +1495,8 @@ impl Inst { } /// AGGRESSIVE: FBinary associations can result in precision and rounding - /// errors + /// errors, and integer associations may lead to overflow and undefined + /// behaviors. pub fn is_associative(self, ctx: &Context) -> bool { use InstKind as Ik; @@ -1506,9 +1507,10 @@ impl Inst { | Ik::IBinary(IBinaryOp::And) | Ik::IBinary(IBinaryOp::Or) | Ik::IBinary(IBinaryOp::Xor) - | Ik::IBinary(IBinaryOp::Min) // TODO: aggressive option - | Ik::IBinary(IBinaryOp::Max) /* | Ik::FBinary(FBinaryOp::Add) // aggressive - * | Ik::FBinary(FBinaryOp::Mul) // aggressive */ + | Ik::IBinary(IBinaryOp::Min) + | Ik::IBinary(IBinaryOp::Max) + | Ik::FBinary(FBinaryOp::Add) // aggressive + | Ik::FBinary(FBinaryOp::Mul) // aggressive ) } diff --git a/src/ir/passes/constant_phi.rs b/src/ir/passes/constant_phi.rs index 2e7d955..bf3ea26 100644 --- a/src/ir/passes/constant_phi.rs +++ b/src/ir/passes/constant_phi.rs @@ -1,5 +1,6 @@ use std::collections::{HashMap, HashSet}; +use super::control_flow::CfgCanonicalize; use crate::{ collections::linked_list::LinkedListContainerPtr, ir::{ @@ -71,6 +72,10 @@ impl GlobalPassMut for ElimConstantPhi { impl TransformPass for ElimConstantPhi { fn register(passman: &mut PassManager) { - passman.register_transform(ELIM_CONSTANT_PHI, ElimConstantPhi, Vec::new()); + passman.register_transform( + ELIM_CONSTANT_PHI, + ElimConstantPhi, + vec![Box::new(CfgCanonicalize)], + ); } } diff --git a/src/ir/passes/fold.rs b/src/ir/passes/fold.rs index a41d5ee..7040628 100644 --- a/src/ir/passes/fold.rs +++ b/src/ir/passes/fold.rs @@ -32,7 +32,7 @@ impl LocalPassMut for ConstantFolding { if inst.results(ctx).len() != 1 { continue; } - if let Some(constant) = inst.fold(ctx, &mut self.fold_ctx) { + if let Some(constant) = inst.fold(ctx, &mut self.fold_ctx, false) { let value = inst.result(ctx, 0); self.fold_ctx.set(value, constant); folded_insts.push(inst); diff --git a/src/ir/passes/instcombine.rs b/src/ir/passes/instcombine.rs index 68fe5ff..de64731 100644 --- a/src/ir/passes/instcombine.rs +++ b/src/ir/passes/instcombine.rs @@ -51,24 +51,6 @@ //! //! [1]: where @ is a commutative binary operator. -// TODO: Now this pass only contains a few simple rules, we need to add more -// rules to make it more powerful. -// -// TODO: Some rules **MIGHT** be applicable to floating-point instructions. -// -// TODO: Some simplification might extend the liverange of the value, which can -// potentially increase the register pressure. Maybe we need a `sink` pass to -// sink the instructions. -// -// TODO: We are not sure about the sequence of the rules. Theoretically, -// because of the iterative feature of the pass manager, the sequence of the -// rules should not matter, but we need to test it. -// -// TODO: There are aggressive rules in this pass, maybe we should separate them -// from the non-aggressive ones. -// -// TODO: Find a way to test these rules one by one. - use crate::{ collections::linked_list::{LinkedListContainerPtr, LinkedListNodePtr}, ir::{ @@ -88,6 +70,10 @@ use crate::{ pub const INSTCOMBINE: &str = "instcombine"; +pub const AGGRESSIVE_INSTCOMBINE: &str = "aggressive-instcombine"; + +pub const ADVANCED_INSTCOMBINE: &str = "advanced-instcombine"; + /// A rule for instcombine. struct Rule { /// The rewriter function. @@ -97,42 +83,43 @@ struct Rule { rewriter: fn(&mut Context, Inst) -> bool, } -pub struct InstCombine { +pub struct Instcombine { + rules: Vec, +} + +pub struct AggressiveInstcombine { + rules: Vec, +} + +pub struct AdvancedInstcombine { rules: Vec, } -impl Default for InstCombine { +impl Default for Instcombine { fn default() -> Self { Self { rules: vec![ - mv_same_together(), // aggressive - sub_identity_to_zero(), // aggressive + mv_const_rhs(), mul_zero_elim(), mul_one_elim(), - mv_const_rhs(), add_zero_elim(), assoc_sub_zero(), sub_zero_elim(), offset_zero_elim(), add_to_mul(), - mul_to_shl(), - assoc_const(), // aggressive - distributive_one(), // aggressive - distributive(), // aggressive div_one_elim(), div_neg_one_elim(), - div_to_shift(), rem_one_elim(), - rem_to_shift(), - div_rem_to_mul(), shl_zero_elim(), // not tested shr_zero_elim(), // not tested + redistribute_const(), // aggressive + // reassociate() // ], } } } -impl LocalPassMut for InstCombine { +impl LocalPassMut for Instcombine { type Output = (); fn run(&mut self, ctx: &mut Context, func: Func) -> PassResult<(Self::Output, bool)> { @@ -164,7 +151,7 @@ impl LocalPassMut for InstCombine { } } -impl GlobalPassMut for InstCombine { +impl GlobalPassMut for Instcombine { type Output = (); fn run(&mut self, ctx: &mut Context) -> PassResult<(Self::Output, bool)> { @@ -179,13 +166,148 @@ impl GlobalPassMut for InstCombine { } } -impl TransformPass for InstCombine { +impl TransformPass for Instcombine { fn register(passman: &mut crate::ir::passman::PassManager) { let pass = Self::default(); passman.register_transform(INSTCOMBINE, pass, Vec::new()); } } +impl Default for AdvancedInstcombine { + fn default() -> Self { + Self { + rules: vec![ + div_to_shift(), + rem_to_shift(), + div_rem_to_mul(), + mul_to_shl(), + ], + } + } +} + +impl LocalPassMut for AdvancedInstcombine { + type Output = (); + + fn run(&mut self, ctx: &mut Context, func: Func) -> PassResult<(Self::Output, bool)> { + let mut changed = false; + + let mut cursor = func.cursor(); + while let Some(block) = cursor.next(ctx) { + let mut cursor = block.cursor(); + while let Some(inst) = cursor.next(ctx) { + if !inst.is_used(ctx) { + // if the instruction's results have no users, we can move to the next + // instruction. + continue; + } + for rule in &self.rules { + if (rule.rewriter)(ctx, inst) { + changed = true; + if !inst.is_used(ctx) { + // if the instruction's results have no users, we can move to the next + // instruction without applying other rules + break; + } + } + } + } + } + + Ok(((), changed)) + } +} + +impl GlobalPassMut for AdvancedInstcombine { + type Output = (); + + fn run(&mut self, ctx: &mut Context) -> PassResult<(Self::Output, bool)> { + let mut changed = false; + + for func in ctx.funcs() { + let ((), local_changed) = LocalPassMut::run(self, ctx, func)?; + changed |= local_changed; + } + + Ok(((), changed)) + } +} + +impl TransformPass for AdvancedInstcombine { + fn register(passman: &mut crate::ir::passman::PassManager) { + let pass = Self::default(); + passman.register_transform(ADVANCED_INSTCOMBINE, pass, Vec::new()); + } +} + +impl Default for AggressiveInstcombine { + fn default() -> Self { + Self { + rules: vec![ + mv_same_together(), // aggressive + sub_identity_to_zero(), // aggressive + assoc_const(), // aggressive + distributive_one(), // aggressive + distributive(), // aggressive + ], + } + } +} + +impl LocalPassMut for AggressiveInstcombine { + type Output = (); + + fn run(&mut self, ctx: &mut Context, func: Func) -> PassResult<(Self::Output, bool)> { + let mut changed = false; + + let mut cursor = func.cursor(); + while let Some(block) = cursor.next(ctx) { + let mut cursor = block.cursor(); + while let Some(inst) = cursor.next(ctx) { + if !inst.is_used(ctx) { + // if the instruction's results have no users, we can move to the next + // instruction. + continue; + } + for rule in &self.rules { + if (rule.rewriter)(ctx, inst) { + changed = true; + if !inst.is_used(ctx) { + // if the instruction's results have no users, we can move to the next + // instruction without applying other rules + break; + } + } + } + } + } + + Ok(((), changed)) + } +} + +impl GlobalPassMut for AggressiveInstcombine { + type Output = (); + + fn run(&mut self, ctx: &mut Context) -> PassResult<(Self::Output, bool)> { + let mut changed = false; + + for func in ctx.funcs() { + let ((), local_changed) = LocalPassMut::run(self, ctx, func)?; + changed |= local_changed; + } + + Ok(((), changed)) + } +} + +impl TransformPass for AggressiveInstcombine { + fn register(passman: &mut crate::ir::passman::PassManager) { + let pass = Self::default(); + passman.register_transform(AGGRESSIVE_INSTCOMBINE, pass, Vec::new()); + } +} + /// Move constant to the right hand side. /// /// This applies to commutative instructions, when the lhs is a constant and the @@ -565,89 +687,49 @@ const fn assoc_sub_zero() -> Rule { } } -// ---------------------------- AGGRESSIVE RULES ---------------------------- // -// TODO: We might need to consider the overflow. - -/// Eliminate subtraction same operand. -/// -/// - `x - x => 0` -/// - `x - (x + y) => 0 - y` -/// - `x - (y + x) => 0 - y` -/// - `(x + y) - x => y` -/// - `(y + x) - x => y` -const fn sub_identity_to_zero() -> Rule { +// Eliminate division by one +const fn div_one_elim() -> Rule { Rule { rewriter: |ctx, inst| { - if let Ik::IBinary(IBinaryOp::Sub) = inst.kind(ctx) { + if let Ik::IBinary(IBinaryOp::SDiv) = inst.kind(ctx) { let lhs = inst.operand(ctx, 0); let rhs = inst.operand(ctx, 1); let dst = inst.result(ctx, 0); - let bitwidth = dst.ty(ctx).bitwidth(ctx); - - if lhs == rhs { - let zero = Inst::iconst( - ctx, - IntConstant::zero(dst.ty(ctx).bitwidth(ctx) as u8), - dst.ty(ctx), - ); - inst.insert_after(ctx, zero); - for user in dst.users(ctx) { - user.replace(ctx, dst, zero.result(ctx, 0)); - } - return true; - } - if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - if let Ik::IBinary(IBinaryOp::Add) = rhs_inst.kind(ctx) { - let rhs_lhs = rhs_inst.operand(ctx, 0); - let rhs_rhs = rhs_inst.operand(ctx, 1); - if lhs == rhs_lhs { - // x - (x + y) => 0 - y - let zero = - Inst::iconst(ctx, IntConstant::zero(bitwidth as u8), dst.ty(ctx)); - let neg_rhs = - Inst::ibinary(ctx, IBinaryOp::Sub, zero.result(ctx, 0), rhs_rhs); - - inst.insert_after(ctx, zero); - zero.insert_after(ctx, neg_rhs); - - for user in dst.users(ctx) { - user.replace(ctx, dst, neg_rhs.result(ctx, 0)); - } - return true; - } else if lhs == rhs_rhs { - // x - (y + x) => 0 - y - let zero = - Inst::iconst(ctx, IntConstant::zero(bitwidth as u8), dst.ty(ctx)); - let neg_rhs = - Inst::ibinary(ctx, IBinaryOp::Sub, zero.result(ctx, 0), rhs_lhs); - - inst.insert_after(ctx, zero); - zero.insert_after(ctx, neg_rhs); - + if let Ik::IConst(v) = rhs_inst.kind(ctx) { + if v.is_one() { for user in dst.users(ctx) { - user.replace(ctx, dst, neg_rhs.result(ctx, 0)); + user.replace(ctx, dst, lhs); } return true; } } } + } + false + }, + } +} - if let ValueKind::InstResult { inst: lhs_inst, .. } = lhs.kind(ctx) { - if let Ik::IBinary(IBinaryOp::Add) = lhs_inst.kind(ctx) { - let lhs_lhs = lhs_inst.operand(ctx, 0); - let lhs_rhs = lhs_inst.operand(ctx, 1); - if rhs == lhs_lhs { - // (x + y) - x => y - for user in dst.users(ctx) { - user.replace(ctx, dst, lhs_rhs); - } - return true; - } else if rhs == lhs_rhs { - // (y + x) - x => y +// Eliminate division by negative one +const fn div_neg_one_elim() -> Rule { + Rule { + rewriter: |ctx, inst| { + if let Ik::IBinary(IBinaryOp::SDiv) = inst.kind(ctx) { + let lhs = inst.operand(ctx, 0); + let rhs = inst.operand(ctx, 1); + let dst = inst.result(ctx, 0); + + if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { + if let Ik::IConst(v) = rhs_inst.kind(ctx) { + if v.as_signed() == -1 { + let zero = Inst::iconst(ctx, 0, dst.ty(ctx)); + let neg = Inst::ibinary(ctx, IBinaryOp::Sub, zero.result(ctx, 0), lhs); + inst.insert_after(ctx, zero); + zero.insert_after(ctx, neg); for user in dst.users(ctx) { - user.replace(ctx, dst, lhs_lhs); + user.replace(ctx, dst, neg.result(ctx, 0)); } return true; } @@ -659,42 +741,196 @@ const fn sub_identity_to_zero() -> Rule { } } -/// Associativity to combine constants. -/// -/// Add: -/// -/// - `(x + c) + d => x + (c + d)` -/// - `(x - c) + d => x - (c - d)` -/// - `d + (x + c) => x + (c + d)` -/// - `d + (x - c) => x + (d - c)` -/// -/// Sub: -/// -/// - `(x + c) - d => x + (c - d)` -/// - `(x - c) - d => x - (c + d)` -/// - `d - (x + c)`: NOT COMBINABLE, we need an additional `0 - x`, not -/// profitable -/// - `d - (c - x) => x + (d - c)` -/// -/// Mul: -/// -/// - `(x * c) * d => x * (c * d)` -/// - `d * (x * c) => x * (c * d)` -/// -/// The result can be constant folded, we don't need to fold the constant here, -/// we just create new instructions and replace def-use. However, we should -/// check constants to make the transformation profitable. -/// -/// We should also note the commutative property of addition and multiplication. -/// -/// TODO: Can integer division be combined? Signed or unsigned? -const fn assoc_const() -> Rule { - use IBinaryOp as Op; +/// Replace division with shift (and add). +const fn div_to_shift() -> Rule { + Rule { + rewriter: |ctx, inst| { + if let Ik::IBinary(IBinaryOp::SDiv) = inst.kind(ctx) { + let lhs = inst.operand(ctx, 0); + let rhs = inst.operand(ctx, 1); + let dst = inst.result(ctx, 0); + + if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { + if let Ik::IConst(mut v) = rhs_inst.kind(ctx) { + let is_v_neg = if v.as_signed() < 0 { + v = IntConstant::from(-v.as_signed()); + true + } else { + false + }; + if v.is_power_of_two() { + let k = v.trailing_zeros(); + if k == 0 { + return false; + } + let bitwidth = lhs.ty(ctx).bitwidth(ctx) as u8; + let shamt_ks1 = IntConstant::from(k - 1); + let shamt_wsk = IntConstant::from(bitwidth as u32 - k); + let shamt_k: IntConstant = IntConstant::from(k); + + let temp0 = Inst::iconst(ctx, shamt_ks1, lhs.ty(ctx)); + let temp1 = + Inst::ibinary(ctx, IBinaryOp::AShr, lhs, temp0.result(ctx, 0)); + let temp2 = Inst::iconst(ctx, shamt_wsk, lhs.ty(ctx)); + let temp3 = Inst::ibinary( + ctx, + IBinaryOp::LShr, + temp1.result(ctx, 0), + temp2.result(ctx, 0), + ); + let temp4 = + Inst::ibinary(ctx, IBinaryOp::Add, lhs, temp3.result(ctx, 0)); + let temp5 = Inst::iconst(ctx, shamt_k, dst.ty(ctx)); + let final_inst = Inst::ibinary( + ctx, + IBinaryOp::AShr, + temp4.result(ctx, 0), + temp5.result(ctx, 0), + ); + + inst.insert_after(ctx, temp0); + temp0.insert_after(ctx, temp1); + temp1.insert_after(ctx, temp2); + temp2.insert_after(ctx, temp3); + temp3.insert_after(ctx, temp4); + temp4.insert_after(ctx, temp5); + temp5.insert_after(ctx, final_inst); + let dst_new = if is_v_neg { + let i_zero = Inst::iconst(ctx, 0, dst.ty(ctx)); + let i_neg = Inst::ibinary( + ctx, + IBinaryOp::Sub, + i_zero.result(ctx, 0), + final_inst.result(ctx, 0), + ); + final_inst.insert_after(ctx, i_zero); + i_zero.insert_after(ctx, i_neg); + i_neg.result(ctx, 0) + } else { + final_inst.result(ctx, 0) + }; + + for user in dst.users(ctx) { + user.replace(ctx, dst, dst_new); + } + return true; + } + } + } + } + false + }, + } +} + +// Eliminate modulo by (negative) one +const fn rem_one_elim() -> Rule { + Rule { + rewriter: |ctx, inst| { + if let Ik::IBinary(IBinaryOp::SRem) = inst.kind(ctx) { + let rhs = inst.operand(ctx, 1); + let dst = inst.result(ctx, 0); + + if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { + if let Ik::IConst(mut v) = rhs_inst.kind(ctx) { + if v.as_signed() < 0 { + v = IntConstant::from(-v.as_signed()) + } + if v.is_one() { + let zero = Inst::iconst(ctx, 0, dst.ty(ctx)); + inst.insert_after(ctx, zero); + + for user in dst.users(ctx) { + user.replace(ctx, dst, zero.result(ctx, 0)); + } + return true; + } + } + } + } + false + }, + } +} + +/// Replace modulo with shift (and sub). +const fn rem_to_shift() -> Rule { + Rule { + rewriter: |ctx, inst| { + if let Ik::IBinary(IBinaryOp::SRem) = inst.kind(ctx) { + let lhs = inst.operand(ctx, 0); + let rhs = inst.operand(ctx, 1); + let dst = inst.result(ctx, 0); + + if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { + if let Ik::IConst(mut v) = rhs_inst.kind(ctx) { + if v.as_signed() < 0 { + v = IntConstant::from(-v.as_signed()) + } + if v.is_power_of_two() { + let k = v.trailing_zeros(); + if k == 0 { + return false; + } + let bitwidth = lhs.ty(ctx).bitwidth(ctx) as u8; + let shamt_wsk = IntConstant::from(bitwidth as u32 - k); + let andwith = IntConstant::from(v.as_signed() as u32 - 1); + + let tmp0 = + Inst::iconst(ctx, IntConstant::from(bitwidth - 1), lhs.ty(ctx)); + let tmp1 = + Inst::ibinary(ctx, IBinaryOp::AShr, lhs, tmp0.result(ctx, 0)); + let tmp2 = Inst::iconst(ctx, shamt_wsk, dst.ty(ctx)); + let tmp3 = Inst::ibinary( + ctx, + IBinaryOp::LShr, + tmp1.result(ctx, 0), + tmp2.result(ctx, 0), + ); + let tmp4 = Inst::ibinary(ctx, IBinaryOp::Add, lhs, tmp3.result(ctx, 0)); + let tmp5 = Inst::iconst(ctx, andwith, dst.ty(ctx)); + let tmp6 = Inst::ibinary( + ctx, + IBinaryOp::And, + tmp4.result(ctx, 0), + tmp5.result(ctx, 0), + ); + let final_inst = Inst::ibinary( + ctx, + IBinaryOp::Sub, + tmp6.result(ctx, 0), + tmp3.result(ctx, 0), + ); + + inst.insert_after(ctx, tmp0); + tmp0.insert_after(ctx, tmp1); + tmp1.insert_after(ctx, tmp2); + tmp2.insert_after(ctx, tmp3); + tmp3.insert_after(ctx, tmp4); + tmp4.insert_after(ctx, tmp5); + tmp5.insert_after(ctx, tmp6); + tmp6.insert_after(ctx, final_inst); + + for user in dst.users(ctx) { + user.replace(ctx, dst, final_inst.result(ctx, 0)); + } + return true; + } + } + } + } + false + }, + } +} +const fn div_rem_to_mul() -> Rule { Rule { rewriter: |ctx, inst| { if let Ik::IBinary(op) = inst.kind(ctx) { - if !matches!(op, Op::Add | Op::Sub | Op::Mul) { + let is_div = matches!(op, IBinaryOp::SDiv); + let is_rem = matches!(op, IBinaryOp::SRem); + if !is_div && !is_rem { return false; } @@ -702,422 +938,190 @@ const fn assoc_const() -> Rule { let rhs = inst.operand(ctx, 1); let dst = inst.result(ctx, 0); - // consider the situation when lhs is a binary - if let ValueKind::InstResult { inst: lhs_inst, .. } = lhs.kind(ctx) { - if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - match (lhs_inst.kind(ctx), rhs_inst.kind(ctx)) { - (Ik::IBinary(lhs_op), Ik::IConst(_)) => { - // `(x + c) + d => x + (c + d)` - // `(x - c) + d => x - (c - d)` - // `(x + c) - d => x + (c - d)` - // `(x - c) - d => x - (c + d)` - // `(x * c) * d => x * (c * d)` - - let lhs_lhs = lhs_inst.operand(ctx, 0); - let lhs_rhs = lhs_inst.operand(ctx, 1); + if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { + if let Ik::IConst(mut v) = rhs_inst.kind(ctx) { + if v.is_zero() { + return false; + } + let is_v_neg = if v.as_signed() < 0 { + v = IntConstant::from(-v.as_signed()); + true + } else { + false + }; + if !v.is_power_of_two() { + let bitwidth = lhs.ty(ctx).bitwidth(ctx); + if bitwidth != 32 { + return false; + } + // TODO: + // (magi, disp) = magic(rhs); + // bitwidth = 32 + // + // // mulh v1= lhs, magi + // let v2 = i64 (Ty::int(64)) + // + // srai v2= v1, (disp - bitwidth) + // srli v3= lhs, (bitwidth - 1) + // add ans= v2, v3 + // + let int64 = Ty::int(ctx, 64); + let (magi, disp) = magic(bitwidth as u64, v.as_signed() as u64); - match (lhs_op, op) { - (Op::Add, Op::Add) => { - if let ValueKind::InstResult { - inst: lhs_lhs_inst, .. - } = lhs_lhs.kind(ctx) - { - if let Ik::IConst(_) = lhs_lhs_inst.kind(ctx) { - // (c + x) + d => x + (c + d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Add, - lhs_lhs, - rhs, - ); - let new_add = Inst::ibinary( - ctx, - IBinaryOp::Add, - lhs_rhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_add); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_add.result(ctx, 0)); - } - return true; - } - } - if let ValueKind::InstResult { - inst: lhs_rhs_inst, .. - } = lhs_rhs.kind(ctx) - { - if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { - // (x + c) + d => x + (c + d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Add, - lhs_rhs, - rhs, - ); - let new_add = Inst::ibinary( - ctx, - IBinaryOp::Add, - lhs_lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_add); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_add.result(ctx, 0)); - } - return true; - } - } - } - (Op::Sub, Op::Add) => { - if let ValueKind::InstResult { - inst: lhs_rhs_inst, .. - } = lhs_rhs.kind(ctx) - { - if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { - // (x - c) + d => x - (c - d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Sub, - lhs_rhs, - rhs, - ); - let new_sub = Inst::ibinary( - ctx, - IBinaryOp::Sub, - lhs_lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_sub); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_sub.result(ctx, 0)); - } - return true; - } - } - } - (Op::Add, Op::Sub) => { - if let ValueKind::InstResult { - inst: lhs_lhs_inst, .. - } = lhs_lhs.kind(ctx) - { - if let Ik::IConst(_) = lhs_lhs_inst.kind(ctx) { - // (c + x) - d => x + (c - d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Sub, - lhs_lhs, - rhs, - ); - let new_add = Inst::ibinary( - ctx, - IBinaryOp::Add, - lhs_rhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_add); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_add.result(ctx, 0)); - } - return true; - } - } - if let ValueKind::InstResult { - inst: lhs_rhs_inst, .. - } = lhs_rhs.kind(ctx) - { - if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { - // (x + c) - d => x + (c - d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Sub, - lhs_rhs, - rhs, - ); - let new_add = Inst::ibinary( - ctx, - IBinaryOp::Add, - lhs_lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_add); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_add.result(ctx, 0)); - } - return true; - } - } - } - (Op::Sub, Op::Sub) => { - if let ValueKind::InstResult { - inst: lhs_rhs_inst, .. - } = lhs_rhs.kind(ctx) - { - if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { - // (x - c) - d => x - (c + d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Add, - lhs_rhs, - rhs, - ); - let new_sub = Inst::ibinary( - ctx, - IBinaryOp::Sub, - lhs_lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_sub); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_sub.result(ctx, 0)); - } - return true; - } - } - } - (Op::Mul, Op::Mul) => { - if let ValueKind::InstResult { - inst: lhs_lhs_inst, .. - } = lhs_lhs.kind(ctx) - { - if let Ik::IConst(_) = lhs_lhs_inst.kind(ctx) { - // (c * x) * d => x * (c * d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Mul, - lhs_lhs, - rhs, - ); - let new_mul = Inst::ibinary( - ctx, - IBinaryOp::Mul, - lhs_rhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); - } - return true; - } - } - if let ValueKind::InstResult { - inst: lhs_rhs_inst, .. - } = lhs_rhs.kind(ctx) - { - if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { - // (x * c) * d => x * (c * d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Mul, - lhs_rhs, - rhs, - ); - let new_mul = Inst::ibinary( - ctx, - IBinaryOp::Mul, - lhs_lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); - } - return true; - } - } - } - _ => {} - } + // temp0-temp5: 使用64位整数的乘法和右移来模拟32位整数与魔数的高位乘法。 + let temp0 = Inst::cast(ctx, CastOp::SExt, lhs, int64); + let temp1 = Inst::iconst(ctx, magi, int64); + let temp2 = Inst::ibinary( + ctx, + IBinaryOp::Mul, + temp0.result(ctx, 0), + temp1.result(ctx, 0), + ); + let temp3 = Inst::iconst(ctx, IntConstant::from(disp), int64); + let temp4 = Inst::ibinary( + ctx, + IBinaryOp::AShr, + temp2.result(ctx, 0), + temp3.result(ctx, 0), + ); + let temp5 = + Inst::cast(ctx, CastOp::Trunc, temp4.result(ctx, 0), lhs.ty(ctx)); + // temp6-temp7: 获得符号位用于修正 + let temp6 = Inst::iconst( + ctx, + IntConstant::from(bitwidth as u64 - 1), + lhs.ty(ctx), + ); + let temp7 = + Inst::ibinary(ctx, IBinaryOp::LShr, lhs, temp6.result(ctx, 0)); + // final_inst: 使用符号位修正结果,得到除法结果。 + let temp8 = Inst::ibinary( + ctx, + IBinaryOp::Add, + temp5.result(ctx, 0), + temp7.result(ctx, 0), + ); + + inst.insert_after(ctx, temp0); + temp0.insert_after(ctx, temp1); + temp1.insert_after(ctx, temp2); + temp2.insert_after(ctx, temp3); + temp3.insert_after(ctx, temp4); + temp4.insert_after(ctx, temp5); + temp5.insert_after(ctx, temp6); + temp6.insert_after(ctx, temp7); + temp7.insert_after(ctx, temp8); + + // 处理取模的情况,并确定final_inst。 + let final_inst = if is_div { + temp8 + } else if is_rem { + let temp9 = Inst::iconst(ctx, v, lhs.ty(ctx)); + let temp10 = Inst::ibinary( + ctx, + IBinaryOp::Mul, + temp8.result(ctx, 0), + temp9.result(ctx, 0), + ); + let temp11 = + Inst::ibinary(ctx, IBinaryOp::Sub, lhs, temp10.result(ctx, 0)); + + temp8.insert_after(ctx, temp9); + temp9.insert_after(ctx, temp10); + temp10.insert_after(ctx, temp11); + + temp11 + } else { + panic!("unreachable") + }; + + // 处理结果的符号,并确定dst_new。 + let dst_new = if is_v_neg { + let i_zero = Inst::iconst(ctx, 0, dst.ty(ctx)); + let i_neg = Inst::ibinary( + ctx, + IBinaryOp::Sub, + i_zero.result(ctx, 0), + final_inst.result(ctx, 0), + ); + final_inst.insert_after(ctx, i_zero); + i_zero.insert_after(ctx, i_neg); + i_neg.result(ctx, 0) + } else { + final_inst.result(ctx, 0) + }; + + for user in dst.users(ctx) { + user.replace(ctx, dst, dst_new); } - (Ik::IConst(_), Ik::IBinary(rhs_op)) => { - // - `d + (x + c) => x + (c + d)` - // - `d + (x - c) => x + (d - c)` - // - `d - (x + c)`: NOT COMBINABLE, we need an additional `0 - x`, - // not profitable - // - `d - (c - x) => x + (d - c)` - // - `d * (x * c) => x * (c * d)` + return true; + } + } + } + } + false + }, + } +} - let rhs_lhs = rhs_inst.operand(ctx, 0); - let rhs_rhs = rhs_inst.operand(ctx, 1); +// 仅用于div_rem_to_mul。 +fn magic(w: u64, d: u64) -> (u64, u64) { + // w = bitwidth + // d = divisor + let nc = (1 << (w - 1)) - (1 << (w - 1)) % d - 1; // FIXME: 93_nested_call.sy div 0 + let mut p = w; + while 1 << p <= nc * (d - (1 << p) % d) { + p += 1; + } + let s = p; + let m = ((1 << p) + d - (1 << p) % d) / d; - match (op, rhs_op) { - (Op::Add, Op::Add) => { - if let ValueKind::InstResult { - inst: rhs_lhs_inst, .. - } = rhs_lhs.kind(ctx) - { - if let Ik::IConst(_) = rhs_lhs_inst.kind(ctx) { - // d + (c + x) => x + (c + d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Add, - rhs_lhs, - lhs, - ); - let new_add = Inst::ibinary( - ctx, - IBinaryOp::Add, - rhs_rhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_add); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_add.result(ctx, 0)); - } - return true; - } - } - if let ValueKind::InstResult { - inst: rhs_rhs_inst, .. - } = rhs_rhs.kind(ctx) - { - if let Ik::IConst(_) = rhs_rhs_inst.kind(ctx) { - // d + (x + c) => x + (c + d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Add, - rhs_rhs, - lhs, - ); - let new_add = Inst::ibinary( - ctx, - IBinaryOp::Add, - rhs_lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_add); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_add.result(ctx, 0)); - } - return true; - } - } - } - (Op::Add, Op::Sub) => { - if let ValueKind::InstResult { - inst: rhs_rhs_inst, .. - } = rhs_rhs.kind(ctx) - { - if let Ik::IConst(_) = rhs_rhs_inst.kind(ctx) { - // d + (x - c) => x + (d - c) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Sub, - lhs, - rhs_rhs, - ); - let new_add = Inst::ibinary( - ctx, - IBinaryOp::Add, - rhs_lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_add); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_add.result(ctx, 0)); - } - return true; - } - } - } - (Op::Sub, Op::Sub) => { - if let ValueKind::InstResult { - inst: rhs_lhs_inst, .. - } = rhs_lhs.kind(ctx) - { - if let Ik::IConst(_) = rhs_lhs_inst.kind(ctx) { - // d - (c - x) => x + (d - c) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Sub, - lhs, - rhs_lhs, - ); - let new_add = Inst::ibinary( - ctx, - IBinaryOp::Add, - rhs_rhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_add); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_add.result(ctx, 0)); - } - return true; - } - } - } - (Op::Mul, Op::Mul) => { - if let ValueKind::InstResult { - inst: rhs_lhs_inst, .. - } = rhs_lhs.kind(ctx) - { - if let Ik::IConst(_) = rhs_lhs_inst.kind(ctx) { - // d * (c * x) => x * (c * d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Mul, - rhs_lhs, - lhs, - ); - let new_mul = Inst::ibinary( - ctx, - IBinaryOp::Mul, - rhs_rhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); - } - return true; - } - } - if let ValueKind::InstResult { - inst: rhs_rhs_inst, .. - } = rhs_rhs.kind(ctx) - { - if let Ik::IConst(_) = rhs_rhs_inst.kind(ctx) { - // d * (x * c) => x * (c * d) - let new_rhs = Inst::ibinary( - ctx, - IBinaryOp::Mul, - rhs_rhs, - lhs, - ); - let new_mul = Inst::ibinary( - ctx, - IBinaryOp::Mul, - rhs_lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); - } - return true; - } - } - } - _ => {} - } + // m = magi(c number) + // s = disp(lacement) + (m, s) +} + +// Eliminate shift by zero +const fn shl_zero_elim() -> Rule { + Rule { + rewriter: |ctx, inst| { + if let Ik::IBinary(IBinaryOp::Shl) = inst.kind(ctx) { + let lhs = inst.operand(ctx, 0); + let rhs = inst.operand(ctx, 1); + let dst = inst.result(ctx, 0); + + if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { + if let Ik::IConst(v) = rhs_inst.kind(ctx) { + if v.is_zero() { + for user in dst.users(ctx) { + user.replace(ctx, dst, lhs); + } + return true; + } + } + } + } + false + }, + } +} + +// Eliminate shift by zero +const fn shr_zero_elim() -> Rule { + Rule { + rewriter: |ctx, inst| { + if let Ik::IBinary(IBinaryOp::LShr | IBinaryOp::AShr) = inst.kind(ctx) { + let lhs = inst.operand(ctx, 0); + let rhs = inst.operand(ctx, 1); + let dst = inst.result(ctx, 0); + + if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { + if let Ik::IConst(v) = rhs_inst.kind(ctx) { + if v.is_zero() { + for user in dst.users(ctx) { + user.replace(ctx, dst, lhs); } - _ => {} + return true; } } } @@ -1127,410 +1131,559 @@ const fn assoc_const() -> Rule { } } -/// Distributive property special case. -/// -/// - `(x * c) + x` => `x * (c + 1)` -/// - `x + (x * c)` => `x * (c + 1)` -/// - `(x * c) - x` => `x * (c - 1) -/// - `x - (x * c)` => `x * (1 - c)` +// ---------------------------- AGGRESSIVE RULES ---------------------------- // +// TODO: We might need to consider the overflow. + +/// Eliminate subtraction same operand. /// -/// c should be a constant, to reduce the complexity. -const fn distributive_one() -> Rule { +/// - `x - x => 0` +/// - `x - (x + y) => 0 - y` +/// - `x - (y + x) => 0 - y` +/// - `(x + y) - x => y` +/// - `(y + x) - x => y` +const fn sub_identity_to_zero() -> Rule { Rule { rewriter: |ctx, inst| { - if let Ik::IBinary(op) = inst.kind(ctx) { - if !matches!(op, IBinaryOp::Add | IBinaryOp::Sub) { - return false; - } - let op = *op; - + if let Ik::IBinary(IBinaryOp::Sub) = inst.kind(ctx) { let lhs = inst.operand(ctx, 0); let rhs = inst.operand(ctx, 1); let dst = inst.result(ctx, 0); let bitwidth = dst.ty(ctx).bitwidth(ctx); - if let ValueKind::InstResult { inst: lhs_inst, .. } = lhs.kind(ctx) { - if let Ik::IBinary(IBinaryOp::Mul) = lhs_inst.kind(ctx) { - let lhs_lhs = lhs_inst.operand(ctx, 0); - let lhs_rhs = lhs_inst.operand(ctx, 1); - - if lhs_lhs == rhs { - // (x * c) +- x => x * (c +- 1) - if let ValueKind::InstResult { - inst: lhs_rhs_inst, .. - } = lhs_rhs.kind(ctx) - { - if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { - let one = Inst::iconst( - ctx, - IntConstant::one(bitwidth as u8), - dst.ty(ctx), - ); - let new_rhs = - Inst::ibinary(ctx, op, lhs_rhs, one.result(ctx, 0)); - let new_mul = Inst::ibinary( - ctx, - IBinaryOp::Mul, - lhs_lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, one); - one.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); - } - - return true; - } - } - } - if lhs_rhs == rhs { - // (c * x) +- x => x * (c +- 1) - if let ValueKind::InstResult { - inst: lhs_lhs_inst, .. - } = lhs_lhs.kind(ctx) - { - if let Ik::IConst(_) = lhs_lhs_inst.kind(ctx) { - let one = Inst::iconst( - ctx, - IntConstant::one(bitwidth as u8), - dst.ty(ctx), - ); - let new_rhs = - Inst::ibinary(ctx, op, lhs_lhs, one.result(ctx, 0)); - let new_mul = Inst::ibinary( - ctx, - IBinaryOp::Mul, - lhs_rhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, one); - one.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); - } - - return true; - } - } - } + if lhs == rhs { + let zero = Inst::iconst( + ctx, + IntConstant::zero(dst.ty(ctx).bitwidth(ctx) as u8), + dst.ty(ctx), + ); + inst.insert_after(ctx, zero); + for user in dst.users(ctx) { + user.replace(ctx, dst, zero.result(ctx, 0)); } + return true; } if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - if let Ik::IBinary(IBinaryOp::Mul) = rhs_inst.kind(ctx) { + if let Ik::IBinary(IBinaryOp::Add) = rhs_inst.kind(ctx) { let rhs_lhs = rhs_inst.operand(ctx, 0); let rhs_rhs = rhs_inst.operand(ctx, 1); + if lhs == rhs_lhs { + // x - (x + y) => 0 - y + let zero = + Inst::iconst(ctx, IntConstant::zero(bitwidth as u8), dst.ty(ctx)); + let neg_rhs = + Inst::ibinary(ctx, IBinaryOp::Sub, zero.result(ctx, 0), rhs_rhs); - if rhs_lhs == lhs { - // x +- (x * c) => x * (1 +- c) - if let ValueKind::InstResult { - inst: rhs_rhs_inst, .. - } = rhs_rhs.kind(ctx) - { - if let Ik::IConst(_) = rhs_rhs_inst.kind(ctx) { - let one = Inst::iconst( - ctx, - IntConstant::one(bitwidth as u8), - dst.ty(ctx), - ); - let new_rhs = - Inst::ibinary(ctx, op, one.result(ctx, 0), rhs_rhs); - let new_mul = Inst::ibinary( - ctx, - IBinaryOp::Mul, - lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, one); - one.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); - } + inst.insert_after(ctx, zero); + zero.insert_after(ctx, neg_rhs); - return true; - } + for user in dst.users(ctx) { + user.replace(ctx, dst, neg_rhs.result(ctx, 0)); + } + return true; + } else if lhs == rhs_rhs { + // x - (y + x) => 0 - y + let zero = + Inst::iconst(ctx, IntConstant::zero(bitwidth as u8), dst.ty(ctx)); + let neg_rhs = + Inst::ibinary(ctx, IBinaryOp::Sub, zero.result(ctx, 0), rhs_lhs); + + inst.insert_after(ctx, zero); + zero.insert_after(ctx, neg_rhs); + + for user in dst.users(ctx) { + user.replace(ctx, dst, neg_rhs.result(ctx, 0)); } + return true; } - if rhs_rhs == lhs { - // x +- (c * x) => x * (1 +- c) - if let ValueKind::InstResult { - inst: rhs_lhs_inst, .. - } = rhs_lhs.kind(ctx) - { - if let Ik::IConst(_) = rhs_lhs_inst.kind(ctx) { - let one = Inst::iconst( - ctx, - IntConstant::one(bitwidth as u8), - dst.ty(ctx), - ); - let new_rhs = - Inst::ibinary(ctx, op, one.result(ctx, 0), rhs_lhs); - let new_mul = Inst::ibinary( - ctx, - IBinaryOp::Mul, - lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, one); - one.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); - } + } + } - return true; - } + if let ValueKind::InstResult { inst: lhs_inst, .. } = lhs.kind(ctx) { + if let Ik::IBinary(IBinaryOp::Add) = lhs_inst.kind(ctx) { + let lhs_lhs = lhs_inst.operand(ctx, 0); + let lhs_rhs = lhs_inst.operand(ctx, 1); + if rhs == lhs_lhs { + // (x + y) - x => y + for user in dst.users(ctx) { + user.replace(ctx, dst, lhs_rhs); + } + return true; + } else if rhs == lhs_rhs { + // (y + x) - x => y + for user in dst.users(ctx) { + user.replace(ctx, dst, lhs_lhs); } + return true; } } } } - false }, } } -/// Distributive property for strength reduction. +/// Associativity to combine constants. /// -/// - `x * y + x * z => x * (y + z)` -/// - `x * y - z * y => (x - z) * y` -/// - `x * y - x * z => x * (y - z)` -/// - `x * y + z * y => (x + z) * y` -/// - `x / y + z / y => (x + z) / y` -/// - `x / y - z / y => (x - z) / y` +/// Add: /// -/// For constants, we should compare the underlying value, but it can actually -/// be done by global value numbering. +/// - `(x + c) + d => x + (c + d)` +/// - `(x - c) + d => x - (c - d)` +/// - `d + (x + c) => x + (c + d)` +/// - `d + (x - c) => x + (d - c)` /// -/// TODO: SDiv is considered, but how about UDiv? -const fn distributive() -> Rule { +/// Sub: +/// +/// - `(x + c) - d => x + (c - d)` +/// - `(x - c) - d => x - (c + d)` +/// - `d - (x + c)`: NOT COMBINABLE, we need an additional `0 - x`, not +/// profitable +/// - `d - (c - x) => x + (d - c)` +/// +/// Mul: +/// +/// - `(x * c) * d => x * (c * d)` +/// - `d * (x * c) => x * (c * d)` +/// +/// The result can be constant folded, we don't need to fold the constant here, +/// we just create new instructions and replace def-use. However, we should +/// check constants to make the transformation profitable. +/// +/// We should also note the commutative property of addition and multiplication. +/// +/// TODO: Can integer division be combined? Signed or unsigned? +const fn assoc_const() -> Rule { use IBinaryOp as Op; Rule { rewriter: |ctx, inst| { if let Ik::IBinary(op) = inst.kind(ctx) { + if !matches!(op, Op::Add | Op::Sub | Op::Mul) { + return false; + } + let lhs = inst.operand(ctx, 0); let rhs = inst.operand(ctx, 1); let dst = inst.result(ctx, 0); - if !matches!(op, Op::Add | Op::Sub) { - return false; - } - - let op = *op; + // consider the situation when lhs is a binary + if let ValueKind::InstResult { inst: lhs_inst, .. } = lhs.kind(ctx) { + if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { + match (lhs_inst.kind(ctx), rhs_inst.kind(ctx)) { + (Ik::IBinary(lhs_op), Ik::IConst(_)) => { + // `(x + c) + d => x + (c + d)` + // `(x - c) + d => x - (c - d)` + // `(x + c) - d => x + (c - d)` + // `(x - c) - d => x - (c + d)` + // `(x * c) * d => x * (c * d)` - if let ( - ValueKind::InstResult { inst: lhs_inst, .. }, - ValueKind::InstResult { inst: rhs_inst, .. }, - ) = (lhs.kind(ctx), rhs.kind(ctx)) - { - if let (Ik::IBinary(lhs_op), Ik::IBinary(rhs_op)) = - (lhs_inst.kind(ctx), rhs_inst.kind(ctx)) - { - if lhs_op == rhs_op && matches!(lhs_op, Op::Mul | Op::SDiv) { - let lhs_lhs = lhs_inst.operand(ctx, 0); - let lhs_rhs = lhs_inst.operand(ctx, 1); - let rhs_lhs = rhs_inst.operand(ctx, 0); - let rhs_rhs = rhs_inst.operand(ctx, 1); + let lhs_lhs = lhs_inst.operand(ctx, 0); + let lhs_rhs = lhs_inst.operand(ctx, 1); - if let Op::Mul = lhs_op { - if lhs_lhs == rhs_lhs { - // x * y + x * z => x * (y + z) - // x * y - x * z => x * (y - z) - let new_rhs = Inst::ibinary(ctx, op, lhs_rhs, rhs_rhs); - let new_mul = Inst::ibinary( - ctx, - Op::Mul, - lhs_lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); - } - return true; - } else if lhs_lhs == rhs_rhs { - // x * y + z * x => x * (y + z) - // x * y - z * x => x * (y - z) - let new_rhs = Inst::ibinary(ctx, op, lhs_rhs, rhs_lhs); - let new_mul = Inst::ibinary( - ctx, - Op::Mul, - lhs_lhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); + match (lhs_op, op) { + (Op::Add, Op::Add) => { + if let ValueKind::InstResult { + inst: lhs_lhs_inst, .. + } = lhs_lhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_lhs_inst.kind(ctx) { + // (c + x) + d => x + (c + d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Add, + lhs_lhs, + rhs, + ); + let new_add = Inst::ibinary( + ctx, + IBinaryOp::Add, + lhs_rhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_add); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_add.result(ctx, 0)); + } + return true; + } + } + if let ValueKind::InstResult { + inst: lhs_rhs_inst, .. + } = lhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { + // (x + c) + d => x + (c + d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Add, + lhs_rhs, + rhs, + ); + let new_add = Inst::ibinary( + ctx, + IBinaryOp::Add, + lhs_lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_add); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_add.result(ctx, 0)); + } + return true; + } + } } - return true; - } else if lhs_rhs == rhs_lhs { - // y * x + x * z => x * (y + z) - // y * x - x * z => x * (y - z) - let new_rhs = Inst::ibinary(ctx, op, lhs_lhs, rhs_rhs); - let new_mul = Inst::ibinary( - ctx, - Op::Mul, - lhs_rhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); + (Op::Sub, Op::Add) => { + if let ValueKind::InstResult { + inst: lhs_rhs_inst, .. + } = lhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { + // (x - c) + d => x - (c - d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Sub, + lhs_rhs, + rhs, + ); + let new_sub = Inst::ibinary( + ctx, + IBinaryOp::Sub, + lhs_lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_sub); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_sub.result(ctx, 0)); + } + return true; + } + } } - return true; - } else if lhs_rhs == rhs_rhs { - // y * x + z * x => x * (y + z) - // y * x - z * x => x * (y - z) - let new_rhs = Inst::ibinary(ctx, op, lhs_lhs, rhs_lhs); - let new_mul = Inst::ibinary( - ctx, - Op::Mul, - lhs_rhs, - new_rhs.result(ctx, 0), - ); - inst.insert_after(ctx, new_rhs); - new_rhs.insert_after(ctx, new_mul); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_mul.result(ctx, 0)); + (Op::Add, Op::Sub) => { + if let ValueKind::InstResult { + inst: lhs_lhs_inst, .. + } = lhs_lhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_lhs_inst.kind(ctx) { + // (c + x) - d => x + (c - d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Sub, + lhs_lhs, + rhs, + ); + let new_add = Inst::ibinary( + ctx, + IBinaryOp::Add, + lhs_rhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_add); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_add.result(ctx, 0)); + } + return true; + } + } + if let ValueKind::InstResult { + inst: lhs_rhs_inst, .. + } = lhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { + // (x + c) - d => x + (c - d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Sub, + lhs_rhs, + rhs, + ); + let new_add = Inst::ibinary( + ctx, + IBinaryOp::Add, + lhs_lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_add); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_add.result(ctx, 0)); + } + return true; + } + } } - return true; - } - } - - if let Op::SDiv = lhs_op { - if lhs_rhs == rhs_rhs { - // x / y + z / y => (x + z) / y - // x / y - z / y => (x - z) / y - let new_lhs = Inst::ibinary(ctx, op, lhs_lhs, rhs_lhs); - let new_div = Inst::ibinary( - ctx, - Op::SDiv, - new_lhs.result(ctx, 0), - lhs_rhs, - ); - inst.insert_after(ctx, new_lhs); - new_lhs.insert_after(ctx, new_div); - for user in dst.users(ctx) { - user.replace(ctx, dst, new_div.result(ctx, 0)); + (Op::Sub, Op::Sub) => { + if let ValueKind::InstResult { + inst: lhs_rhs_inst, .. + } = lhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { + // (x - c) - d => x - (c + d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Add, + lhs_rhs, + rhs, + ); + let new_sub = Inst::ibinary( + ctx, + IBinaryOp::Sub, + lhs_lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_sub); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_sub.result(ctx, 0)); + } + return true; + } + } } - return true; + (Op::Mul, Op::Mul) => { + if let ValueKind::InstResult { + inst: lhs_lhs_inst, .. + } = lhs_lhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_lhs_inst.kind(ctx) { + // (c * x) * d => x * (c * d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Mul, + lhs_lhs, + rhs, + ); + let new_mul = Inst::ibinary( + ctx, + IBinaryOp::Mul, + lhs_rhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } + return true; + } + } + if let ValueKind::InstResult { + inst: lhs_rhs_inst, .. + } = lhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { + // (x * c) * d => x * (c * d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Mul, + lhs_rhs, + rhs, + ); + let new_mul = Inst::ibinary( + ctx, + IBinaryOp::Mul, + lhs_lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } + return true; + } + } + } + _ => {} } } - } - } - } - } - false - }, - } -} - -/// Move same instructions' same operands together, -/// instruction that can be moved together must be associative. -/// -/// `(x @ y) @ x` or `x @ (x @ y)` or `x @ (y @ x)` or `(y @ x) @ x` => `(x @ x) -/// @ y` -/// -/// note: this rule better to run after add_to_mul -/// AGGRESSIVE: float may cause accuracy problem -const fn mv_same_together() -> Rule { - Rule { - rewriter: |ctx, inst| { - if inst.is_associative(ctx) { - let lhs = inst.operand(ctx, 0); - let rhs = inst.operand(ctx, 1); - let dst = inst.result(ctx, 0); - let inst_kind = inst.kind(ctx).clone(); - let ty = dst.ty(ctx); + (Ik::IConst(_), Ik::IBinary(rhs_op)) => { + // - `d + (x + c) => x + (c + d)` + // - `d + (x - c) => x + (d - c)` + // - `d - (x + c)`: NOT COMBINABLE, we need an additional `0 - x`, + // not profitable + // - `d - (c - x) => x + (d - c)` + // - `d * (x * c) => x * (c * d)` - let mut insts_to_move = - if let ValueKind::InstResult { inst: lhs_inst, .. } = lhs.kind(ctx) { - if lhs_inst.kind(ctx) == &inst_kind { - let lhs_lhs = lhs_inst.operand(ctx, 0); - let lhs_rhs = lhs_inst.operand(ctx, 1); - if lhs_lhs == rhs { - Some((rhs, lhs_rhs)) - } else if lhs_rhs == rhs { - Some((rhs, lhs_lhs)) - } else { - None - } - } else { - None - } - } else { - None - }; - if insts_to_move.is_none() { - insts_to_move = - if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - if rhs_inst.kind(ctx) == &inst_kind { let rhs_lhs = rhs_inst.operand(ctx, 0); let rhs_rhs = rhs_inst.operand(ctx, 1); - if rhs_lhs == lhs { - Some((lhs, rhs_rhs)) - } else if rhs_rhs == lhs { - Some((lhs, rhs_lhs)) - } else { - None + + match (op, rhs_op) { + (Op::Add, Op::Add) => { + if let ValueKind::InstResult { + inst: rhs_lhs_inst, .. + } = rhs_lhs.kind(ctx) + { + if let Ik::IConst(_) = rhs_lhs_inst.kind(ctx) { + // d + (c + x) => x + (c + d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Add, + rhs_lhs, + lhs, + ); + let new_add = Inst::ibinary( + ctx, + IBinaryOp::Add, + rhs_rhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_add); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_add.result(ctx, 0)); + } + return true; + } + } + if let ValueKind::InstResult { + inst: rhs_rhs_inst, .. + } = rhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = rhs_rhs_inst.kind(ctx) { + // d + (x + c) => x + (c + d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Add, + rhs_rhs, + lhs, + ); + let new_add = Inst::ibinary( + ctx, + IBinaryOp::Add, + rhs_lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_add); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_add.result(ctx, 0)); + } + return true; + } + } + } + (Op::Add, Op::Sub) => { + if let ValueKind::InstResult { + inst: rhs_rhs_inst, .. + } = rhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = rhs_rhs_inst.kind(ctx) { + // d + (x - c) => x + (d - c) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Sub, + lhs, + rhs_rhs, + ); + let new_add = Inst::ibinary( + ctx, + IBinaryOp::Add, + rhs_lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_add); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_add.result(ctx, 0)); + } + return true; + } + } + } + (Op::Sub, Op::Sub) => { + if let ValueKind::InstResult { + inst: rhs_lhs_inst, .. + } = rhs_lhs.kind(ctx) + { + if let Ik::IConst(_) = rhs_lhs_inst.kind(ctx) { + // d - (c - x) => x + (d - c) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Sub, + lhs, + rhs_lhs, + ); + let new_add = Inst::ibinary( + ctx, + IBinaryOp::Add, + rhs_rhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_add); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_add.result(ctx, 0)); + } + return true; + } + } + } + (Op::Mul, Op::Mul) => { + if let ValueKind::InstResult { + inst: rhs_lhs_inst, .. + } = rhs_lhs.kind(ctx) + { + if let Ik::IConst(_) = rhs_lhs_inst.kind(ctx) { + // d * (c * x) => x * (c * d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Mul, + rhs_lhs, + lhs, + ); + let new_mul = Inst::ibinary( + ctx, + IBinaryOp::Mul, + rhs_rhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } + return true; + } + } + if let ValueKind::InstResult { + inst: rhs_rhs_inst, .. + } = rhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = rhs_rhs_inst.kind(ctx) { + // d * (x * c) => x * (c * d) + let new_rhs = Inst::ibinary( + ctx, + IBinaryOp::Mul, + rhs_rhs, + lhs, + ); + let new_mul = Inst::ibinary( + ctx, + IBinaryOp::Mul, + rhs_lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } + return true; + } + } + } + _ => {} } - } else { - None - } - } else { - None - }; - } - if let Some((twice, other)) = insts_to_move { - let inst_inner = - Inst::new(ctx, inst_kind.clone(), vec![ty], vec![twice, twice]); - let inst_outer = Inst::new( - ctx, - inst_kind.clone(), - vec![ty], - vec![inst_inner.result(ctx, 0), other], - ); - inst.insert_after(ctx, inst_inner); - inst_inner.insert_after(ctx, inst_outer); - for user in dst.users(ctx) { - user.replace(ctx, dst, inst_outer.result(ctx, 0)); - } - return true; - } else { - return false; - } - } - false - }, - } -} - -// Eliminate division by one -const fn div_one_elim() -> Rule { - Rule { - rewriter: |ctx, inst| { - if let Ik::IBinary(IBinaryOp::SDiv) = inst.kind(ctx) { - let lhs = inst.operand(ctx, 0); - let rhs = inst.operand(ctx, 1); - let dst = inst.result(ctx, 0); - - if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - if let Ik::IConst(v) = rhs_inst.kind(ctx) { - if v.is_one() { - for user in dst.users(ctx) { - user.replace(ctx, dst, lhs); } - return true; + _ => {} } } } @@ -1540,209 +1693,305 @@ const fn div_one_elim() -> Rule { } } -// Eliminate division by negative one -const fn div_neg_one_elim() -> Rule { +/// Distributive property special case. +/// +/// - `(x * c) + x` => `x * (c + 1)` +/// - `x + (x * c)` => `x * (c + 1)` +/// - `(x * c) - x` => `x * (c - 1) +/// - `x - (x * c)` => `x * (1 - c)` +/// +/// c should be a constant, to reduce the complexity. +const fn distributive_one() -> Rule { Rule { rewriter: |ctx, inst| { - if let Ik::IBinary(IBinaryOp::SDiv) = inst.kind(ctx) { - let lhs = inst.operand(ctx, 0); - let rhs = inst.operand(ctx, 1); - let dst = inst.result(ctx, 0); - - if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - if let Ik::IConst(v) = rhs_inst.kind(ctx) { - if v.as_signed() == -1 { - let zero = Inst::iconst(ctx, 0, dst.ty(ctx)); - let neg = Inst::ibinary(ctx, IBinaryOp::Sub, zero.result(ctx, 0), lhs); - inst.insert_after(ctx, zero); - zero.insert_after(ctx, neg); - for user in dst.users(ctx) { - user.replace(ctx, dst, neg.result(ctx, 0)); - } - return true; - } - } + if let Ik::IBinary(op) = inst.kind(ctx) { + if !matches!(op, IBinaryOp::Add | IBinaryOp::Sub) { + return false; } - } - false - }, - } -} + let op = *op; -/// Replace division with shift (and add). -const fn div_to_shift() -> Rule { - Rule { - rewriter: |ctx, inst| { - if let Ik::IBinary(IBinaryOp::SDiv) = inst.kind(ctx) { let lhs = inst.operand(ctx, 0); let rhs = inst.operand(ctx, 1); let dst = inst.result(ctx, 0); - if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - if let Ik::IConst(mut v) = rhs_inst.kind(ctx) { - let is_v_neg = if v.as_signed() < 0 { - v = IntConstant::from(-v.as_signed()); - true - } else { - false - }; - if v.is_power_of_two() { - let k = v.trailing_zeros(); - if k == 0 { - return false; - } - let bitwidth = lhs.ty(ctx).bitwidth(ctx) as u8; - let shamt_ks1 = IntConstant::from(k - 1); - let shamt_wsk = IntConstant::from(bitwidth as u32 - k); - let shamt_k: IntConstant = IntConstant::from(k); + let bitwidth = dst.ty(ctx).bitwidth(ctx); - let temp0 = Inst::iconst(ctx, shamt_ks1, lhs.ty(ctx)); - let temp1 = - Inst::ibinary(ctx, IBinaryOp::AShr, lhs, temp0.result(ctx, 0)); - let temp2 = Inst::iconst(ctx, shamt_wsk, lhs.ty(ctx)); - let temp3 = Inst::ibinary( - ctx, - IBinaryOp::LShr, - temp1.result(ctx, 0), - temp2.result(ctx, 0), - ); - let temp4 = - Inst::ibinary(ctx, IBinaryOp::Add, lhs, temp3.result(ctx, 0)); - let temp5 = Inst::iconst(ctx, shamt_k, dst.ty(ctx)); - let final_inst = Inst::ibinary( - ctx, - IBinaryOp::AShr, - temp4.result(ctx, 0), - temp5.result(ctx, 0), - ); + if let ValueKind::InstResult { inst: lhs_inst, .. } = lhs.kind(ctx) { + if let Ik::IBinary(IBinaryOp::Mul) = lhs_inst.kind(ctx) { + let lhs_lhs = lhs_inst.operand(ctx, 0); + let lhs_rhs = lhs_inst.operand(ctx, 1); - inst.insert_after(ctx, temp0); - temp0.insert_after(ctx, temp1); - temp1.insert_after(ctx, temp2); - temp2.insert_after(ctx, temp3); - temp3.insert_after(ctx, temp4); - temp4.insert_after(ctx, temp5); - temp5.insert_after(ctx, final_inst); - let dst_new = if is_v_neg { - let i_zero = Inst::iconst(ctx, 0, dst.ty(ctx)); - let i_neg = Inst::ibinary( - ctx, - IBinaryOp::Sub, - i_zero.result(ctx, 0), - final_inst.result(ctx, 0), - ); - final_inst.insert_after(ctx, i_zero); - i_zero.insert_after(ctx, i_neg); - i_neg.result(ctx, 0) - } else { - final_inst.result(ctx, 0) - }; + if lhs_lhs == rhs { + // (x * c) +- x => x * (c +- 1) + if let ValueKind::InstResult { + inst: lhs_rhs_inst, .. + } = lhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { + let one = Inst::iconst( + ctx, + IntConstant::one(bitwidth as u8), + dst.ty(ctx), + ); + let new_rhs = + Inst::ibinary(ctx, op, lhs_rhs, one.result(ctx, 0)); + let new_mul = Inst::ibinary( + ctx, + IBinaryOp::Mul, + lhs_lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, one); + one.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } - for user in dst.users(ctx) { - user.replace(ctx, dst, dst_new); + return true; + } } - return true; } - } - } - } - false - }, - } -} - -// Eliminate modulo by (negative) one -const fn rem_one_elim() -> Rule { - Rule { - rewriter: |ctx, inst| { - if let Ik::IBinary(IBinaryOp::SRem) = inst.kind(ctx) { - let rhs = inst.operand(ctx, 1); - let dst = inst.result(ctx, 0); + if lhs_rhs == rhs { + // (c * x) +- x => x * (c +- 1) + if let ValueKind::InstResult { + inst: lhs_lhs_inst, .. + } = lhs_lhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_lhs_inst.kind(ctx) { + let one = Inst::iconst( + ctx, + IntConstant::one(bitwidth as u8), + dst.ty(ctx), + ); + let new_rhs = + Inst::ibinary(ctx, op, lhs_lhs, one.result(ctx, 0)); + let new_mul = Inst::ibinary( + ctx, + IBinaryOp::Mul, + lhs_rhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, one); + one.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } + + return true; + } + } + } + } + } if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - if let Ik::IConst(mut v) = rhs_inst.kind(ctx) { - if v.as_signed() < 0 { - v = IntConstant::from(-v.as_signed()) + if let Ik::IBinary(IBinaryOp::Mul) = rhs_inst.kind(ctx) { + let rhs_lhs = rhs_inst.operand(ctx, 0); + let rhs_rhs = rhs_inst.operand(ctx, 1); + + if rhs_lhs == lhs { + // x +- (x * c) => x * (1 +- c) + if let ValueKind::InstResult { + inst: rhs_rhs_inst, .. + } = rhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = rhs_rhs_inst.kind(ctx) { + let one = Inst::iconst( + ctx, + IntConstant::one(bitwidth as u8), + dst.ty(ctx), + ); + let new_rhs = + Inst::ibinary(ctx, op, one.result(ctx, 0), rhs_rhs); + let new_mul = Inst::ibinary( + ctx, + IBinaryOp::Mul, + lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, one); + one.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } + + return true; + } + } } - if v.is_one() { - let zero = Inst::iconst(ctx, 0, dst.ty(ctx)); - inst.insert_after(ctx, zero); + if rhs_rhs == lhs { + // x +- (c * x) => x * (1 +- c) + if let ValueKind::InstResult { + inst: rhs_lhs_inst, .. + } = rhs_lhs.kind(ctx) + { + if let Ik::IConst(_) = rhs_lhs_inst.kind(ctx) { + let one = Inst::iconst( + ctx, + IntConstant::one(bitwidth as u8), + dst.ty(ctx), + ); + let new_rhs = + Inst::ibinary(ctx, op, one.result(ctx, 0), rhs_lhs); + let new_mul = Inst::ibinary( + ctx, + IBinaryOp::Mul, + lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, one); + one.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } - for user in dst.users(ctx) { - user.replace(ctx, dst, zero.result(ctx, 0)); + return true; + } } - return true; } } } } + false }, } } -/// Replace modulo with shift (and sub). -const fn rem_to_shift() -> Rule { +/// Distributive property for strength reduction. +/// +/// - `x * y + x * z => x * (y + z)` +/// - `x * y - z * y => (x - z) * y` +/// - `x * y - x * z => x * (y - z)` +/// - `x * y + z * y => (x + z) * y` +/// - `x / y + z / y => (x + z) / y` +/// - `x / y - z / y => (x - z) / y` +/// +/// For constants, we should compare the underlying value, but it can actually +/// be done by global value numbering. +/// +/// TODO: SDiv is considered, but how about UDiv? +const fn distributive() -> Rule { + use IBinaryOp as Op; + Rule { rewriter: |ctx, inst| { - if let Ik::IBinary(IBinaryOp::SRem) = inst.kind(ctx) { + if let Ik::IBinary(op) = inst.kind(ctx) { let lhs = inst.operand(ctx, 0); let rhs = inst.operand(ctx, 1); let dst = inst.result(ctx, 0); - if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - if let Ik::IConst(mut v) = rhs_inst.kind(ctx) { - if v.as_signed() < 0 { - v = IntConstant::from(-v.as_signed()) - } - if v.is_power_of_two() { - let k = v.trailing_zeros(); - if k == 0 { - return false; - } - let bitwidth = lhs.ty(ctx).bitwidth(ctx) as u8; - let shamt_wsk = IntConstant::from(bitwidth as u32 - k); - let andwith = IntConstant::from(v.as_signed() as u32 - 1); + if !matches!(op, Op::Add | Op::Sub) { + return false; + } - let tmp0 = - Inst::iconst(ctx, IntConstant::from(bitwidth - 1), lhs.ty(ctx)); - let tmp1 = - Inst::ibinary(ctx, IBinaryOp::AShr, lhs, tmp0.result(ctx, 0)); - let tmp2 = Inst::iconst(ctx, shamt_wsk, dst.ty(ctx)); - let tmp3 = Inst::ibinary( - ctx, - IBinaryOp::LShr, - tmp1.result(ctx, 0), - tmp2.result(ctx, 0), - ); - let tmp4 = Inst::ibinary(ctx, IBinaryOp::Add, lhs, tmp3.result(ctx, 0)); - let tmp5 = Inst::iconst(ctx, andwith, dst.ty(ctx)); - let tmp6 = Inst::ibinary( - ctx, - IBinaryOp::And, - tmp4.result(ctx, 0), - tmp5.result(ctx, 0), - ); - let final_inst = Inst::ibinary( - ctx, - IBinaryOp::Sub, - tmp6.result(ctx, 0), - tmp3.result(ctx, 0), - ); + let op = *op; - inst.insert_after(ctx, tmp0); - tmp0.insert_after(ctx, tmp1); - tmp1.insert_after(ctx, tmp2); - tmp2.insert_after(ctx, tmp3); - tmp3.insert_after(ctx, tmp4); - tmp4.insert_after(ctx, tmp5); - tmp5.insert_after(ctx, tmp6); - tmp6.insert_after(ctx, final_inst); + if let ( + ValueKind::InstResult { inst: lhs_inst, .. }, + ValueKind::InstResult { inst: rhs_inst, .. }, + ) = (lhs.kind(ctx), rhs.kind(ctx)) + { + if let (Ik::IBinary(lhs_op), Ik::IBinary(rhs_op)) = + (lhs_inst.kind(ctx), rhs_inst.kind(ctx)) + { + if lhs_op == rhs_op && matches!(lhs_op, Op::Mul | Op::SDiv) { + let lhs_lhs = lhs_inst.operand(ctx, 0); + let lhs_rhs = lhs_inst.operand(ctx, 1); + let rhs_lhs = rhs_inst.operand(ctx, 0); + let rhs_rhs = rhs_inst.operand(ctx, 1); - for user in dst.users(ctx) { - user.replace(ctx, dst, final_inst.result(ctx, 0)); + if let Op::Mul = lhs_op { + if lhs_lhs == rhs_lhs { + // x * y + x * z => x * (y + z) + // x * y - x * z => x * (y - z) + let new_rhs = Inst::ibinary(ctx, op, lhs_rhs, rhs_rhs); + let new_mul = Inst::ibinary( + ctx, + Op::Mul, + lhs_lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } + return true; + } else if lhs_lhs == rhs_rhs { + // x * y + z * x => x * (y + z) + // x * y - z * x => x * (y - z) + let new_rhs = Inst::ibinary(ctx, op, lhs_rhs, rhs_lhs); + let new_mul = Inst::ibinary( + ctx, + Op::Mul, + lhs_lhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } + return true; + } else if lhs_rhs == rhs_lhs { + // y * x + x * z => x * (y + z) + // y * x - x * z => x * (y - z) + let new_rhs = Inst::ibinary(ctx, op, lhs_lhs, rhs_rhs); + let new_mul = Inst::ibinary( + ctx, + Op::Mul, + lhs_rhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } + return true; + } else if lhs_rhs == rhs_rhs { + // y * x + z * x => x * (y + z) + // y * x - z * x => x * (y - z) + let new_rhs = Inst::ibinary(ctx, op, lhs_lhs, rhs_lhs); + let new_mul = Inst::ibinary( + ctx, + Op::Mul, + lhs_rhs, + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_mul); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_mul.result(ctx, 0)); + } + return true; + } + } + + if let Op::SDiv = lhs_op { + if lhs_rhs == rhs_rhs { + // x / y + z / y => (x + z) / y + // x / y - z / y => (x - z) / y + let new_lhs = Inst::ibinary(ctx, op, lhs_lhs, rhs_lhs); + let new_div = Inst::ibinary( + ctx, + Op::SDiv, + new_lhs.result(ctx, 0), + lhs_rhs, + ); + inst.insert_after(ctx, new_lhs); + new_lhs.insert_after(ctx, new_div); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_div.result(ctx, 0)); + } + return true; + } } - return true; } } } @@ -1752,139 +2001,78 @@ const fn rem_to_shift() -> Rule { } } -const fn div_rem_to_mul() -> Rule { +/// Move same instructions' same operands together, +/// instruction that can be moved together must be associative. +/// +/// `(x @ y) @ x` or `x @ (x @ y)` or `x @ (y @ x)` or `(y @ x) @ x` => `(x @ x) +/// @ y` +/// +/// note: this rule better to run after add_to_mul +const fn mv_same_together() -> Rule { Rule { rewriter: |ctx, inst| { - if let Ik::IBinary(op) = inst.kind(ctx) { - let is_div = matches!(op, IBinaryOp::SDiv); - let is_rem = matches!(op, IBinaryOp::SRem); - if !is_div && !is_rem { - return false; - } - + if inst.is_associative(ctx) { let lhs = inst.operand(ctx, 0); let rhs = inst.operand(ctx, 1); let dst = inst.result(ctx, 0); + let inst_kind = inst.kind(ctx).clone(); + let ty = dst.ty(ctx); - if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - if let Ik::IConst(mut v) = rhs_inst.kind(ctx) { - if v.is_zero() { - return false; - } - let is_v_neg = if v.as_signed() < 0 { - v = IntConstant::from(-v.as_signed()); - true - } else { - false - }; - if !v.is_power_of_two() { - let bitwidth = lhs.ty(ctx).bitwidth(ctx); - if bitwidth != 32 { - return false; - } - // TODO: - // (magi, disp) = magic(rhs); - // bitwidth = 32 - // - // // mulh v1= lhs, magi - // let v2 = i64 (Ty::int(64)) - // - // srai v2= v1, (disp - bitwidth) - // srli v3= lhs, (bitwidth - 1) - // add ans= v2, v3 - // - let int64 = Ty::int(ctx, 64); - let (magi, disp) = magic(bitwidth as u64, v.as_signed() as u64); - - // temp0-temp5: 使用64位整数的乘法和右移来模拟32位整数与魔数的高位乘法。 - let temp0 = Inst::cast(ctx, CastOp::SExt, lhs, int64); - let temp1 = Inst::iconst(ctx, magi, int64); - let temp2 = Inst::ibinary( - ctx, - IBinaryOp::Mul, - temp0.result(ctx, 0), - temp1.result(ctx, 0), - ); - let temp3 = Inst::iconst(ctx, IntConstant::from(disp), int64); - let temp4 = Inst::ibinary( - ctx, - IBinaryOp::AShr, - temp2.result(ctx, 0), - temp3.result(ctx, 0), - ); - let temp5 = - Inst::cast(ctx, CastOp::Trunc, temp4.result(ctx, 0), lhs.ty(ctx)); - // temp6-temp7: 获得符号位用于修正 - let temp6 = Inst::iconst( - ctx, - IntConstant::from(bitwidth as u64 - 1), - lhs.ty(ctx), - ); - let temp7 = - Inst::ibinary(ctx, IBinaryOp::LShr, lhs, temp6.result(ctx, 0)); - // final_inst: 使用符号位修正结果,得到除法结果。 - let temp8 = Inst::ibinary( - ctx, - IBinaryOp::Add, - temp5.result(ctx, 0), - temp7.result(ctx, 0), - ); - - inst.insert_after(ctx, temp0); - temp0.insert_after(ctx, temp1); - temp1.insert_after(ctx, temp2); - temp2.insert_after(ctx, temp3); - temp3.insert_after(ctx, temp4); - temp4.insert_after(ctx, temp5); - temp5.insert_after(ctx, temp6); - temp6.insert_after(ctx, temp7); - temp7.insert_after(ctx, temp8); - - // 处理取模的情况,并确定final_inst。 - let final_inst = if is_div { - temp8 - } else if is_rem { - let temp9 = Inst::iconst(ctx, v, lhs.ty(ctx)); - let temp10 = Inst::ibinary( - ctx, - IBinaryOp::Mul, - temp8.result(ctx, 0), - temp9.result(ctx, 0), - ); - let temp11 = - Inst::ibinary(ctx, IBinaryOp::Sub, lhs, temp10.result(ctx, 0)); - - temp8.insert_after(ctx, temp9); - temp9.insert_after(ctx, temp10); - temp10.insert_after(ctx, temp11); - - temp11 - } else { - panic!("unreachable") - }; - - // 处理结果的符号,并确定dst_new。 - let dst_new = if is_v_neg { - let i_zero = Inst::iconst(ctx, 0, dst.ty(ctx)); - let i_neg = Inst::ibinary( - ctx, - IBinaryOp::Sub, - i_zero.result(ctx, 0), - final_inst.result(ctx, 0), - ); - final_inst.insert_after(ctx, i_zero); - i_zero.insert_after(ctx, i_neg); - i_neg.result(ctx, 0) + let mut insts_to_move = + if let ValueKind::InstResult { inst: lhs_inst, .. } = lhs.kind(ctx) { + if lhs_inst.kind(ctx) == &inst_kind { + let lhs_lhs = lhs_inst.operand(ctx, 0); + let lhs_rhs = lhs_inst.operand(ctx, 1); + if lhs_lhs == rhs { + Some((rhs, lhs_rhs)) + } else if lhs_rhs == rhs { + Some((rhs, lhs_lhs)) } else { - final_inst.result(ctx, 0) - }; - - for user in dst.users(ctx) { - user.replace(ctx, dst, dst_new); + None } - return true; + } else { + None } + } else { + None + }; + if insts_to_move.is_none() { + insts_to_move = + if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { + if rhs_inst.kind(ctx) == &inst_kind { + let rhs_lhs = rhs_inst.operand(ctx, 0); + let rhs_rhs = rhs_inst.operand(ctx, 1); + if rhs_lhs == lhs { + Some((lhs, rhs_rhs)) + } else if rhs_rhs == lhs { + Some((lhs, rhs_lhs)) + } else { + None + } + } else { + None + } + } else { + None + }; + } + if let Some((twice, other)) = insts_to_move { + let inst_inner = + Inst::new(ctx, inst_kind.clone(), vec![ty], vec![twice, twice]); + let inst_outer = Inst::new( + ctx, + inst_kind.clone(), + vec![ty], + vec![inst_inner.result(ctx, 0), other], + ); + inst.insert_after(ctx, inst_inner); + inst_inner.insert_after(ctx, inst_outer); + for user in dst.users(ctx) { + user.replace(ctx, dst, inst_outer.result(ctx, 0)); } + return true; + } else { + return false; } } false @@ -1892,39 +2080,55 @@ const fn div_rem_to_mul() -> Rule { } } -// 仅用于div_rem_to_mul。 -fn magic(w: u64, d: u64) -> (u64, u64) { - // w = bitwidth - // d = divisor - let nc = (1 << (w - 1)) - (1 << (w - 1)) % d - 1; // FIXME: 93_nested_call.sy div 0 - let mut p = w; - while 1 << p <= nc * (d - (1 << p) % d) { - p += 1; - } - let s = p; - let m = ((1 << p) + d - (1 << p) % d) / d; - - // m = magi(c number) - // s = disp(lacement) - (m, s) -} - -// Eliminate shift by zero -const fn shl_zero_elim() -> Rule { +/// Distribute add/sub & mul with constants, create more opportunities for +/// induction variables. +/// +/// `(x +- c1) * c2` => x * c2 +- c1 * c2 +const fn redistribute_const() -> Rule { Rule { rewriter: |ctx, inst| { - if let Ik::IBinary(IBinaryOp::Shl) = inst.kind(ctx) { + if let Ik::IBinary(IBinaryOp::Mul) = inst.kind(ctx) { let lhs = inst.operand(ctx, 0); let rhs = inst.operand(ctx, 1); let dst = inst.result(ctx, 0); - if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - if let Ik::IConst(v) = rhs_inst.kind(ctx) { - if v.is_zero() { - for user in dst.users(ctx) { - user.replace(ctx, dst, lhs); + if let ValueKind::InstResult { inst: lhs_inst, .. } = lhs.kind(ctx) { + if let Ik::IBinary(lhs_op) = lhs_inst.kind(ctx) { + let lhs_lhs = lhs_inst.operand(ctx, 0); + let lhs_rhs = lhs_inst.operand(ctx, 1); + let lhs_op = *lhs_op; + + if lhs_op != IBinaryOp::Add && lhs_op != IBinaryOp::Sub { + return false; + } + + if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { + if let Ik::IConst(_) = rhs_inst.kind(ctx) { + if let ValueKind::InstResult { + inst: lhs_rhs_inst, .. + } = lhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { + let new_lhs = + Inst::ibinary(ctx, IBinaryOp::Mul, lhs_lhs, rhs); + let new_rhs = + Inst::ibinary(ctx, IBinaryOp::Mul, lhs_rhs, rhs); + let new_add = Inst::ibinary( + ctx, + lhs_op, + new_lhs.result(ctx, 0), + new_rhs.result(ctx, 0), + ); + inst.insert_after(ctx, new_lhs); + new_lhs.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_add); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_add.result(ctx, 0)); + } + return true; + } + } } - return true; } } } @@ -1934,22 +2138,121 @@ const fn shl_zero_elim() -> Rule { } } -// Eliminate shift by zero -const fn shr_zero_elim() -> Rule { +/// Associate more. +/// +/// (x +- c1) +- (y +- c2) => (x +- y) +- (c1 +- c2) +const fn reassociate() -> Rule { Rule { rewriter: |ctx, inst| { - if let Ik::IBinary(IBinaryOp::LShr | IBinaryOp::AShr) = inst.kind(ctx) { + if let Ik::IBinary(op) = inst.kind(ctx) { let lhs = inst.operand(ctx, 0); let rhs = inst.operand(ctx, 1); let dst = inst.result(ctx, 0); + let op = *op; - if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { - if let Ik::IConst(v) = rhs_inst.kind(ctx) { - if v.is_zero() { - for user in dst.users(ctx) { - user.replace(ctx, dst, lhs); + if let ValueKind::InstResult { inst: lhs_inst, .. } = lhs.kind(ctx) { + if let Ik::IBinary(lhs_op) = lhs_inst.kind(ctx) { + let lhs_lhs = lhs_inst.operand(ctx, 0); + let lhs_rhs = lhs_inst.operand(ctx, 1); + let lhs_op = *lhs_op; + + if let ValueKind::InstResult { inst: rhs_inst, .. } = rhs.kind(ctx) { + if let Ik::IBinary(rhs_op) = rhs_inst.kind(ctx) { + let rhs_lhs = rhs_inst.operand(ctx, 0); + let rhs_rhs = rhs_inst.operand(ctx, 1); + let rhs_op = *rhs_op; + + if let ValueKind::InstResult { + inst: lhs_rhs_inst, .. + } = lhs_rhs.kind(ctx) + { + if let ValueKind::InstResult { + inst: rhs_rhs_inst, .. + } = rhs_rhs.kind(ctx) + { + if let Ik::IConst(_) = lhs_rhs_inst.kind(ctx) { + if let Ik::IConst(_) = rhs_rhs_inst.kind(ctx) { + // (x + c1) + (y + c2) => (x + y) + (c1 + c2) + // (x + c1) + (y - c2) => (x + y) + (c1 - c2) + // (x - c1) + (y - c2) => (x + y) - (c1 + c2) + // (x - c1) + (y + c2) => (x + y) - (c1 - c2) + // (x + c1) - (y + c2) => (x - y) + (c1 - c2) + // (x + c1) - (y - c2) => (x - y) + (c1 + c2) + // (x - c1) - (y - c2) => (x - y) - (c1 - c2) + // (x - c1) - (y + c2) => (x - y) - (c1 + c2) + + let new_lhs_op = op; + let new_op = lhs_op; + let new_rhs_op = match (lhs_op, op, rhs_op) { + ( + IBinaryOp::Add, + IBinaryOp::Add, + IBinaryOp::Add, + ) => IBinaryOp::Add, + ( + IBinaryOp::Add, + IBinaryOp::Add, + IBinaryOp::Sub, + ) => IBinaryOp::Sub, + ( + IBinaryOp::Sub, + IBinaryOp::Add, + IBinaryOp::Sub, + ) => IBinaryOp::Add, + ( + IBinaryOp::Sub, + IBinaryOp::Add, + IBinaryOp::Add, + ) => IBinaryOp::Sub, + ( + IBinaryOp::Add, + IBinaryOp::Sub, + IBinaryOp::Add, + ) => IBinaryOp::Sub, + ( + IBinaryOp::Add, + IBinaryOp::Sub, + IBinaryOp::Sub, + ) => IBinaryOp::Add, + ( + IBinaryOp::Sub, + IBinaryOp::Sub, + IBinaryOp::Sub, + ) => IBinaryOp::Sub, + ( + IBinaryOp::Sub, + IBinaryOp::Sub, + IBinaryOp::Add, + ) => IBinaryOp::Add, + _ => return false, + }; + + let new_lhs = Inst::ibinary( + ctx, new_lhs_op, lhs_lhs, rhs_lhs, + ); + let new_rhs = Inst::ibinary( + ctx, new_rhs_op, lhs_rhs, rhs_rhs, + ); + let new_inst = Inst::ibinary( + ctx, + new_op, + new_lhs.result(ctx, 0), + new_rhs.result(ctx, 0), + ); + + inst.insert_after(ctx, new_lhs); + new_lhs.insert_after(ctx, new_rhs); + new_rhs.insert_after(ctx, new_inst); + for user in dst.users(ctx) { + user.replace(ctx, dst, new_inst.result(ctx, 0)); + } + + return true; + } + } + } + } } - return true; } } } diff --git a/src/ir/passes/loops/dead_loop_elim.rs b/src/ir/passes/loops/dead_loop_elim.rs index 411fed4..c555630 100644 --- a/src/ir/passes/loops/dead_loop_elim.rs +++ b/src/ir/passes/loops/dead_loop_elim.rs @@ -1,13 +1,15 @@ -use super::{scalar_evolution::LoopScevRecord, Lcssa, LoopSimplify, Scev, ScevAnalysis}; +use super::{scalar_evolution::LoopScevRecord, Lcssa, LoopSimplify, ScevAnalysis}; use crate::{ collections::linked_list::LinkedListContainerPtr, ir::{ - passes::control_flow::CfgCanonicalize, + passes::{ + control_flow::CfgCanonicalize, + loops::scalar_evolution::{LoopBound, LoopBoundCond}, + }, passman::{GlobalPassMut, LocalPass, LocalPassMut, PassResult, TransformPass}, Block, Context, Func, - ICmpCond, Inst, Ty, }, @@ -29,63 +31,46 @@ impl DeadLoopElim { fn process_loop(&mut self, ctx: &mut Context, lp: Loop, scevs: &LoopScevRecord) -> bool { let mut changed = false; - let (lhs, cmp_cond, rhs) = if let Some(bound) = scevs.loop_bounds.get(&lp).unwrap() { + let LoopBound { + block_param: repr, + cond: cmp_cond, + bound, + reversed, + .. + } = if let Some(bound) = scevs.loop_bounds.get(&lp).unwrap() { bound } else { return false; }; - match cmp_cond { - ICmpCond::Eq | ICmpCond::Ne => return false, - // TODO: Support unsigned bounds - ICmpCond::Ule | ICmpCond::Ult => return false, - ICmpCond::Sle | ICmpCond::Slt => {} - } + let preheader = lp.get_preheader(ctx, &self.scev.loops).unwrap(); + let header = lp.header(&self.scev.loops); + let jump = preheader.tail(ctx).unwrap(); - if let Some(Scev { repr, .. }) = scevs - .scevs - .get(&lp) - .unwrap() - .iter() - .find(|s| s.repr == *lhs || s.repr == *rhs) - { - let (bound, cmp_cond, reversed) = if repr == lhs { - (rhs, cmp_cond, false) - } else if repr == rhs { - (lhs, cmp_cond, true) - } else { - unreachable!() - }; - - let preheader = lp.get_preheader(ctx, &self.scev.loop_ctx).unwrap(); - let header = lp.header(&self.scev.loop_ctx); - let jump = preheader.tail(ctx).unwrap(); - - let br = header.tail(ctx).unwrap(); - - assert!(jump.is_jump(ctx)); - - if !br.is_br(ctx) { - return false; - } + let br = header.tail(ctx).unwrap(); + + assert!(jump.is_jump(ctx)); - let initial_arg = jump.succ(ctx, 0).get_arg(*repr).unwrap(); + if !br.is_br(ctx) { + return false; + } - if initial_arg == *bound && *cmp_cond == ICmpCond::Slt && !reversed { - // this is a dead loop, replace the compare result to be always - // false. - let i1 = Ty::int(ctx, 1); - let iconst = Inst::iconst(ctx, false, i1); - header.push_front(ctx, iconst); + let initial_arg = jump.succ(ctx, 0).get_arg(*repr).unwrap(); - let cond = br.operand(ctx, 0); + if initial_arg == *bound && *cmp_cond == LoopBoundCond::Slt && !reversed { + // this is a dead loop, replace the compare result to be always + // false. + let i1 = Ty::int(ctx, 1); + let iconst = Inst::iconst(ctx, false, i1); + header.push_front(ctx, iconst); - for user in cond.users(ctx) { - user.replace(ctx, cond, iconst.result(ctx, 0)); - } + let cond = br.operand(ctx, 0); - changed = true; + for user in cond.users(ctx) { + user.replace(ctx, cond, iconst.result(ctx, 0)); } + + changed = true; } changed @@ -102,8 +87,8 @@ impl LocalPassMut for DeadLoopElim { let mut loops = Vec::new(); - for lp in self.scev.loop_ctx.loops() { - let depth = lp.depth(&self.scev.loop_ctx); + for lp in self.scev.loops.loops() { + let depth = lp.depth(&self.scev.loops); loops.push(LoopWithDepth { lp, depth }); } @@ -121,8 +106,8 @@ impl GlobalPassMut for DeadLoopElim { type Output = (); fn run(&mut self, ctx: &mut Context) -> PassResult<(Self::Output, bool)> { - ctx.alloc_all_names(); - println!("{}", ctx.display(true)); + // ctx.alloc_all_names(); + // println!("{}", ctx.display(true)); let mut changed = false; for func in ctx.funcs() { diff --git a/src/ir/passes/loops/indvar_offset.rs b/src/ir/passes/loops/indvar_offset.rs new file mode 100644 index 0000000..ec28a7a --- /dev/null +++ b/src/ir/passes/loops/indvar_offset.rs @@ -0,0 +1,166 @@ +use super::{scalar_evolution::LoopScevRecord, InductionOp, Lcssa, LoopSimplify, ScevAnalysis}; +use crate::{ + collections::linked_list::{LinkedListContainerPtr, LinkedListNodePtr}, + ir::{ + passes::{control_flow::CfgCanonicalize, loops::Scev, simple_dce::SimpleDce}, + passman::{GlobalPassMut, LocalPass, LocalPassMut, PassManager, PassResult, TransformPass}, + Block, + Context, + Func, + IBinaryOp, + Inst, + InstKind, + }, + utils::{ + def_use::{Usable, User}, + loop_info::{Loop, LoopWithDepth}, + }, +}; + +pub const INDVAR_OFFSET: &str = "indvar-offset"; + +#[derive(Default)] +pub struct IndvarOffset { + scev: ScevAnalysis, +} + +impl IndvarOffset { + fn process_loop(&self, ctx: &mut Context, lp: Loop, scevs: &LoopScevRecord) -> bool { + let preheader = lp.get_preheader(ctx, &self.scev.loops).unwrap(); + let header = lp.header(&self.scev.loops); + + let blocks = lp.get_blocks(ctx, &self.scev.loops); + + if blocks.len() != 2 { + return false; + } + + let body = if blocks[0] == header { + blocks[1] + } else { + blocks[0] + }; + + let mut cursor = body.cursor(); + + let mut changed = false; + + while let Some(inst) = cursor.next(ctx) { + if let InstKind::Offset = inst.kind(ctx) { + let ptr = inst.operand(ctx, 0); + let offset = inst.operand(ctx, 1); + + if let Some(Scev { init, step, op, .. }) = scevs.scevs.get(&offset) { + let new_step = match op { + InductionOp::Add => *step, + InductionOp::Sub => { + // negate the step, step is a loop invariant, sub it with 0 in the + // preheader + let zero = Inst::iconst(ctx, 0, step.ty(ctx)); + let new_step = + Inst::ibinary(ctx, IBinaryOp::Sub, zero.result(ctx, 0), *step); + preheader.push_inst_before_terminator(ctx, zero); + preheader.push_inst_before_terminator(ctx, new_step); + new_step.result(ctx, 0) + } + InductionOp::Mul | InductionOp::SDiv | InductionOp::Shl => return false, + }; + + // the offset is a indvar, we can just induce the pointer, + // and maybe the indvar can be removed later. + + // firstly, create a new block parameter for the pointer + let new_ptr = header.new_param(ctx, ptr.ty(ctx)); + + // secondly, create a new `offset` in the preheader, as the initial value + let init_ptr = Inst::offset(ctx, ptr, *init); + preheader.push_inst_before_terminator(ctx, init_ptr); + + // replace all the use of the old offset-ed ptr with the block parameter + let old = inst.result(ctx, 0); + for user in old.users(ctx) { + user.replace(ctx, old, new_ptr); + } + + // induce the offset + let induced_ptr = Inst::offset(ctx, new_ptr, new_step); + inst.insert_before(ctx, induced_ptr); + + let preds = header.preds(ctx); + assert_eq!(preds.len(), 2); + + for pred in preds { + let tail = pred.tail(ctx).unwrap(); + if pred == preheader { + tail.add_succ_arg(ctx, header, new_ptr, init_ptr.result(ctx, 0)); + } else { + tail.add_succ_arg(ctx, header, new_ptr, induced_ptr.result(ctx, 0)); + } + } + + changed = true; + } + } + } + + changed + } +} + +impl LocalPassMut for IndvarOffset { + type Output = (); + + fn run(&mut self, ctx: &mut Context, func: Func) -> PassResult<(Self::Output, bool)> { + let mut changed = false; + + let scevs = LocalPass::run(&mut self.scev, ctx, func)?; + + let mut loops = Vec::new(); + + for lp in self.scev.loops.loops() { + let depth = lp.depth(&self.scev.loops); + loops.push(LoopWithDepth { lp, depth }); + } + + loops.sort(); + + for LoopWithDepth { lp, .. } in loops { + changed |= self.process_loop(ctx, lp, &scevs); + } + + Ok(((), changed)) + } +} +impl GlobalPassMut for IndvarOffset { + type Output = (); + + fn run(&mut self, ctx: &mut Context) -> PassResult<(Self::Output, bool)> { + let mut changed = false; + for func in ctx.funcs() { + let (_, local_changed) = LocalPassMut::run(self, ctx, func).unwrap(); + changed |= local_changed; + } + Ok(((), changed)) + } +} + +impl TransformPass for IndvarOffset { + fn register(passman: &mut PassManager) + where + Self: Sized, + { + let pass = Self::default(); + + passman.register_transform( + INDVAR_OFFSET, + pass, + vec![ + Box::new(CfgCanonicalize), + Box::new(LoopSimplify::default()), + Box::new(Lcssa::default()), + ], + ); + + passman.add_post_dep(INDVAR_OFFSET, Box::new(SimpleDce::default())); + } +} diff --git a/src/ir/passes/loops/indvar_reduce.rs b/src/ir/passes/loops/indvar_reduce.rs new file mode 100644 index 0000000..30d7292 --- /dev/null +++ b/src/ir/passes/loops/indvar_reduce.rs @@ -0,0 +1,138 @@ +use std::collections::HashMap; + +use super::{scalar_evolution::LoopScevRecord, Lcssa, LoopSimplify, ScevAnalysis}; +use crate::{ + ir::{ + passes::control_flow::CfgCanonicalize, + passman::{GlobalPassMut, LocalPass, LocalPassMut, PassManager, PassResult, TransformPass}, + Block, + Context, + Func, + }, + utils::{ + def_use::{Usable, User}, + loop_info::{Loop, LoopWithDepth}, + }, +}; + +pub const INDVAR_REDUCE: &str = "indvar-reduce"; + +#[derive(Default)] +pub struct IndvarReduce { + scev: ScevAnalysis, +} + +impl IndvarReduce { + fn process_loop(&self, ctx: &mut Context, lp: Loop, scevs: &LoopScevRecord) -> bool { + // remove redundant induction variables. If two indvars share the same step, op, + // and init, we can replace one with the other. + + let header = lp.header(&self.scev.loops); + + let params = header.params(ctx).to_vec(); + + let mut indvar_to_replace = HashMap::new(); + + for (i, param) in params.iter().enumerate() { + if indvar_to_replace.contains_key(param) { + continue; + } + + let indvar = if let Some(indvar) = scevs.scevs.get(param) { + indvar + } else { + continue; + }; + + for &other in params.iter().skip(i + 1) { + let other = if let Some(other) = scevs.scevs.get(&other) { + other + } else { + continue; + }; + + if indvar.init == other.init + && indvar.step == other.step + && indvar.op == other.op + && indvar.modulus == other.modulus + { + indvar_to_replace.insert(other.block_param, *param); + } + } + } + + println!( + "[ indvar-reduce ] reduced {} redundant indvars", + indvar_to_replace.len() + ); + + if indvar_to_replace.is_empty() { + return false; + } + + for (from, to) in indvar_to_replace { + for user in from.users(ctx) { + user.replace(ctx, from, to); + } + } + + true + } +} + +impl LocalPassMut for IndvarReduce { + type Output = (); + + fn run(&mut self, ctx: &mut Context, func: Func) -> PassResult<(Self::Output, bool)> { + let mut changed = false; + + let scevs = LocalPass::run(&mut self.scev, ctx, func)?; + + let mut loops = Vec::new(); + + for lp in self.scev.loops.loops() { + let depth = lp.depth(&self.scev.loops); + loops.push(LoopWithDepth { lp, depth }); + } + + loops.sort(); + + for LoopWithDepth { lp, .. } in loops { + changed |= self.process_loop(ctx, lp, &scevs); + } + + Ok(((), changed)) + } +} + +impl GlobalPassMut for IndvarReduce { + type Output = (); + + fn run(&mut self, ctx: &mut Context) -> PassResult<(Self::Output, bool)> { + let mut changed = false; + for func in ctx.funcs() { + let (_, local_changed) = LocalPassMut::run(self, ctx, func).unwrap(); + changed |= local_changed; + } + Ok(((), changed)) + } +} + +impl TransformPass for IndvarReduce { + fn register(passman: &mut PassManager) + where + Self: Sized, + { + let pass = Self::default(); + + passman.register_transform( + INDVAR_REDUCE, + pass, + vec![ + Box::new(CfgCanonicalize), + Box::new(LoopSimplify::default()), + Box::new(Lcssa::default()), + ], + ); + } +} diff --git a/src/ir/passes/loops/indvar_simplify.rs b/src/ir/passes/loops/indvar_simplify.rs index 640af52..5dcd992 100644 --- a/src/ir/passes/loops/indvar_simplify.rs +++ b/src/ir/passes/loops/indvar_simplify.rs @@ -1,4 +1,10 @@ -use super::{scalar_evolution::LoopScevRecord, Lcssa, LoopSimplify, Scev, ScevAnalysis}; +use super::{ + scalar_evolution::{LoopBound, LoopBoundCond, LoopScevRecord}, + Lcssa, + LoopSimplify, + Scev, + ScevAnalysis, +}; use crate::{ collections::linked_list::LinkedListContainerPtr, ir::{ @@ -7,7 +13,6 @@ use crate::{ Block, Context, Func, - ICmpCond, InstKind, }, utils::{ @@ -28,64 +33,56 @@ impl IndvarSimplify { fn process_loop(&mut self, ctx: &mut Context, lp: Loop, scevs: &LoopScevRecord) -> bool { let mut changed = false; - let (lhs, cmp_cond, rhs) = if let Some(bound) = scevs.loop_bounds.get(&lp).unwrap() { + let LoopBound { + block_param, + cond: cmp_cond, + bound, + reversed, + .. + } = if let Some(bound) = scevs.loop_bounds.get(&lp).unwrap() { bound } else { return false; }; - match cmp_cond { - ICmpCond::Eq | ICmpCond::Ne => return false, - // TODO: Support unsigned bounds - ICmpCond::Ule | ICmpCond::Ult => return false, - ICmpCond::Sle | ICmpCond::Slt => {} + let Scev { + block_param: repr, + step, + op, + .. + } = scevs.scevs.get(block_param).unwrap(); + + let mut step_const = None; + if let Some(inst) = step.def_inst(ctx) { + if let InstKind::IConst(step) = inst.kind(ctx) { + step_const = Some(step.as_signed()); + } } - if let Some(Scev { repr, step, op, .. }) = scevs - .scevs - .get(&lp) - .unwrap() - .iter() - .find(|s| s.repr == *lhs || s.repr == *rhs) - { - let (bound, cmp_cond, reversed) = if repr == lhs { - (rhs, cmp_cond, false) - } else if repr == rhs { - (lhs, cmp_cond, true) - } else { - unreachable!() - }; - - let mut step_const = None; - if let Some(inst) = step.def_inst(ctx) { - if let InstKind::IConst(step) = inst.kind(ctx) { - step_const = Some(step.as_signed()); + if let (InductionOp::Add, LoopBoundCond::Slt, false) = (op, cmp_cond, reversed) { + // TODO: we can calculate the tripcount and simplify more. the tripcount is also + // used in unrolling. + if step_const == Some(1) { + // replace the use of the induction variable with the bound + // we can only replace the use in the header, and only the + // successor jump to exit blocks. + let header = lp.header(&self.scev.loops); + let tail = header.tail(ctx).unwrap(); + + if !tail.is_br(ctx) { + return false; } - } - - if let (InductionOp::Add, ICmpCond::Slt, false) = (op, cmp_cond, reversed) { - if step_const == Some(1) { - // replace the use of the induction variable with the bound - // we can only replace the use in the header, and only the - // successor jump to exit blocks. - let header = lp.header(&self.scev.loop_ctx); - let tail = header.tail(ctx).unwrap(); - if !tail.is_br(ctx) { - return false; - } + let blocks = lp.get_blocks(ctx, &self.scev.loops); - let blocks = lp.get_blocks(ctx, &self.scev.loop_ctx); - - if blocks.len() != 2 { - // only process the loop with 2 blocks, header and body (also the backedge). - return false; - } + if blocks.len() != 2 { + // only process the loop with 2 blocks, header and body (also the backedge). + return false; + } - tail.replace(ctx, *repr, *bound); + tail.replace(ctx, *repr, *bound); - changed = true; - } + changed = true; } } @@ -103,8 +100,8 @@ impl LocalPassMut for IndvarSimplify { let mut loops = Vec::new(); - for lp in self.scev.loop_ctx.loops() { - let depth = lp.depth(&self.scev.loop_ctx); + for lp in self.scev.loops.loops() { + let depth = lp.depth(&self.scev.loops); loops.push(LoopWithDepth { lp, depth }); } diff --git a/src/ir/passes/loops/mod.rs b/src/ir/passes/loops/mod.rs index ebd3ff4..a282d39 100644 --- a/src/ir/passes/loops/mod.rs +++ b/src/ir/passes/loops/mod.rs @@ -1,17 +1,23 @@ mod dead_loop_elim; +mod indvar_offset; +mod indvar_reduce; mod indvar_simplify; mod invariant_motion; mod lcssa; mod peel; mod scalar_evolution; mod simplify; +mod strength_reduction; mod unroll; pub use dead_loop_elim::{DeadLoopElim, DEAD_LOOP_ELIM}; +pub use indvar_offset::{IndvarOffset, INDVAR_OFFSET}; +pub use indvar_reduce::{IndvarReduce, INDVAR_REDUCE}; pub use indvar_simplify::{IndvarSimplify, INDVAR_SIMPLIFY}; pub use invariant_motion::{LoopInvariantMotion, LOOP_INVARIANT_MOTION}; pub use lcssa::{Lcssa, LCSSA}; pub use peel::{LoopPeel, LOOP_PEEL}; pub use scalar_evolution::{InductionOp, Scev, ScevAnalysis}; pub use simplify::{LoopSimplify, LOOP_SIMPLIFY}; +pub use strength_reduction::{LoopStrengthReduction, LOOP_STRENGTH_REDUCTION}; pub use unroll::{LoopUnroll, LOOP_UNROLL}; diff --git a/src/ir/passes/loops/scalar_evolution.rs b/src/ir/passes/loops/scalar_evolution.rs index 816f073..8fd3c50 100644 --- a/src/ir/passes/loops/scalar_evolution.rs +++ b/src/ir/passes/loops/scalar_evolution.rs @@ -9,18 +9,20 @@ use crate::{ Func, IBinaryOp, ICmpCond, + Inst, InstKind, Value, ValueKind, }, utils::{ - cfg::CfgInfo, + cfg::{CfgInfo, CfgNode}, def_use::Usable, dominance::Dominance, loop_info::{Loop, LoopContext}, }, }; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum InductionOp { Add, Sub, @@ -45,13 +47,14 @@ impl TryFrom for InductionOp { } } -/// A record of an induction variable. +/// Basic recurrence record. +#[derive(Debug, Clone)] pub struct Scev { /// The representative value of this induction variable, typically the loop /// parameter. - pub repr: Value, + pub block_param: Value, /// The start value of this induction variable. - pub start: Value, + pub init: Value, /// The evolving step of this induction variable. pub step: Value, /// The evolution method of this induction variable. @@ -60,6 +63,29 @@ pub struct Scev { pub modulus: Option, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum LoopBoundCond { + Slt, + Sle, + Sgt, + Sge, +} + +/// A record for loop bound. +pub struct LoopBound { + /// The block parameter representing the basic recurrence. + pub block_param: Value, + /// The instruction for comparison + pub cmp_inst: Inst, + /// The comparison condition + pub cond: LoopBoundCond, + /// The bound value + pub bound: Value, + /// If this is a reversed comparison, i.e., bound is the lhs in the + /// instruction. + pub reversed: bool, +} + pub struct DisplayScev<'a> { ctx: &'a Context, scev: &'a Scev, @@ -67,7 +93,7 @@ pub struct DisplayScev<'a> { impl<'a> std::fmt::Display for DisplayScev<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let start = self.scev.start.name(self.ctx).unwrap(); + let start = self.scev.init.name(self.ctx).unwrap(); let step = self.scev.step.name(self.ctx).unwrap(); let modulus = self @@ -101,18 +127,16 @@ impl Scev { /// operations. #[derive(Default)] pub struct ScevAnalysis { - pub(super) loop_ctx: LoopContext, + pub(super) loops: LoopContext, pub(super) dominance: Dominance, - // TODO: union-find might be useful for some complex tree-shaped operations, and maybe we can - // re-associate the operations to get a indvar expression. - indvars: Vec, + indvars: FxHashMap, } impl ScevAnalysis { fn process_loop(&mut self, ctx: &Context, lp: Loop) { // after loop-simplify, the header should have exactly two predecessors. one is // the preheader, the other is the latch. - let header = lp.header(&self.loop_ctx); + let header = lp.header(&self.loops); for param in header.params(ctx) { // all the block params in the header are suspected to be induction variables. @@ -141,7 +165,7 @@ impl ScevAnalysis { let pred_block = pred_inst.container(ctx).unwrap(); - if self.loop_ctx.is_in_loop(pred_block, lp) { + if self.loops.is_in_loop(pred_block, lp) { // the incoming value is from inside the loop, it should be the evolving // value of the induction variable. evolving = Some(incoming); @@ -159,15 +183,15 @@ impl ScevAnalysis { } // get the most initial start value. - let mut start = start.unwrap(); + let mut init = start.unwrap(); // the start can be a block param in the preheader, so get the def block of the // start value, if the start is a block param and there is only one predecessor, // get the incoming value. - while let ValueKind::BlockParam { block, .. } = start.kind(ctx) { + while let ValueKind::BlockParam { block, .. } = init.kind(ctx) { if block.preds(ctx).len() == 1 { // one predecessor -> one inst & one succ in the inst -> just get the 0-th user if let Some(succ) = block.users(ctx)[0].succ_to(ctx, *block).next() { - start = succ.get_arg(start).unwrap(); + init = succ.get_arg(init).unwrap(); break; } } else { @@ -186,7 +210,7 @@ impl ScevAnalysis { // // The instruction should be `%evolving = %param, %step`, where `%step` is // an loop-invariant value. Here we simply check if the `%step` is defined - // outside the loop, and let LICM to do the preparation + // outside the loop, and let LICM/GCM to do the preparation let def_inst = if let Some(def_inst) = evolving.def_inst(ctx) { def_inst } else { @@ -208,7 +232,7 @@ impl ScevAnalysis { continue; }; - if self.loop_ctx.is_in_loop(step.def_block(ctx), lp) { + if self.loops.is_in_loop(step.def_block(ctx), lp) { // the step is not defined outside the loop, this is not an induction variable. // again, this is a conservative approach to check, we should let LICM to do the // preparation. @@ -228,39 +252,96 @@ impl ScevAnalysis { }; let indvar = Scev { - repr: *param, - start, + block_param: *param, + init, step, op: ind_op, modulus: None, }; - self.indvars.push(indvar); + self.indvars.insert(*param, indvar); } } } - pub fn find_loop_bound( - &mut self, - ctx: &Context, - lp: Loop, - ) -> Option<(Value, ICmpCond, Value)> { - let header = lp.header(&self.loop_ctx); - // TODO: is this right? + pub fn find_loop_bound(&mut self, ctx: &Context, lp: Loop) -> Option { + let exits = lp.get_exit_blocks(ctx, &self.loops); + + if exits.len() != 1 { + // we only support single exit loop. + println!("[ scev-analysis ] multiple exits in loop, not supported."); + return None; + } + + let header = lp.header(&self.loops); + + let mut pre_exit = false; + for succ in header.succs(ctx) { + if exits.contains(&succ) { + pre_exit = true; + break; + } + } + if !pre_exit { + println!("[ scev-analysis ] header is not pre-exit block, not supported"); + return None; + } + // this should be a branch instruction. let tail = header.tail(ctx).unwrap(); + let mut loop_bound = None; if let InstKind::Br = tail.kind(ctx) { let cond = tail.operand(ctx, 0); if let Some(inst) = cond.def_inst(ctx) { - if let InstKind::IBinary(IBinaryOp::Cmp(cmp_cond)) = inst.kind(ctx) { + if let InstKind::IBinary(IBinaryOp::Cmp(cond @ (ICmpCond::Slt | ICmpCond::Sle))) = + inst.kind(ctx) + { let lhs = inst.operand(ctx, 0); let rhs = inst.operand(ctx, 1); - // not checing if lhs/rhs is the block param/indvar - loop_bound = Some((lhs, *cmp_cond, rhs)); + // the compare form should be: + // 1. indvar invariant + // 2. invariant indvar (revsersed) + let (indvar, invariant, reversed) = if self.indvars.contains_key(&lhs) { + // also check if the bound is loop-invariant. + let def_block = rhs.def_block(ctx); + if self.loops.is_in_loop(def_block, lp) { + return None; + } + (lhs, rhs, false) + } else if self.indvars.contains_key(&rhs) { + let def_block = lhs.def_block(ctx); + if self.loops.is_in_loop(def_block, lp) { + return None; + } + (rhs, lhs, true) + } else { + // the compare is not related to the loop parameter. + return None; + }; + + let cond = if *cond == ICmpCond::Slt && !reversed { + LoopBoundCond::Slt + } else if *cond == ICmpCond::Sle && !reversed { + LoopBoundCond::Sle + } else if *cond == ICmpCond::Slt && reversed { + LoopBoundCond::Sgt + } else if *cond == ICmpCond::Sle && reversed { + LoopBoundCond::Sge + } else { + unreachable!() + }; + + loop_bound = Some(LoopBound { + block_param: indvar, + cmp_inst: inst, + cond, + bound: invariant, + reversed, + }); } } } @@ -272,15 +353,9 @@ impl ScevAnalysis { #[derive(Default)] pub struct LoopScevRecord { /// The loop parameter. - pub loop_bounds: FxHashMap, Option<(Value, ICmpCond, Value)>>, + pub loop_bounds: FxHashMap, Option>, /// The detected loop induction variables. - pub scevs: FxHashMap, Vec>, -} - -impl LoopScevRecord { - pub fn iter(&self) -> impl Iterator { - self.scevs.iter().flat_map(|(_, v)| v.iter()) - } + pub scevs: FxHashMap, } pub struct DisplayLoopScevRecord<'a> { @@ -290,23 +365,39 @@ pub struct DisplayLoopScevRecord<'a> { impl<'a> std::fmt::Display for DisplayLoopScevRecord<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - for (lp, bound) in self.record.loop_bounds.iter() { - if let Some(bound) = bound { - let (lhs, cond, rhs) = bound; + for (_, bound) in self.record.loop_bounds.iter() { + if let Some(LoopBound { + block_param, + cond, + bound, + reversed, + .. + }) = bound + { + let param_name = block_param.name(self.ctx).unwrap(); + let bound_name = bound.name(self.ctx).unwrap(); + + let cond = match cond { + LoopBoundCond::Slt => "slt", + LoopBoundCond::Sle => "sle", + LoopBoundCond::Sgt => "sgt", + LoopBoundCond::Sge => "sge", + }; + writeln!( f, - "loop bound: {} {} {}", - lhs.name(self.ctx).unwrap(), + "{} {} {} {}", + param_name, cond, - rhs.name(self.ctx).unwrap() + bound_name, + if *reversed { "reversed" } else { "" } )?; - } else { - writeln!(f, "loop bound: none")?; } + } - for scev in self.record.scevs.get(lp).unwrap() { - writeln!(f, " {}", scev.display(self.ctx))?; - } + for (param, scev) in self.record.scevs.iter() { + let param_name = param.name(self.ctx).unwrap(); + writeln!(f, "{}: {}", param_name, scev.display(self.ctx))?; } Ok(()) @@ -326,18 +417,19 @@ impl LocalPass for ScevAnalysis { let cfg = CfgInfo::new(ctx, func); self.dominance = Dominance::new(ctx, &cfg); - self.loop_ctx = LoopContext::new(&cfg, &self.dominance); + self.loops = LoopContext::new(&cfg, &self.dominance); self.indvars.clear(); let mut result = LoopScevRecord::default(); - for lp in self.loop_ctx.loops() { + for lp in self.loops.loops() { self.process_loop(ctx, lp); let bound = self.find_loop_bound(ctx, lp); result.loop_bounds.insert(lp, bound); - result.scevs.insert(lp, self.indvars.drain(..).collect()); } + result.scevs = self.indvars.drain().collect(); + Ok(result) } } diff --git a/src/ir/passes/loops/strength_reduction.rs b/src/ir/passes/loops/strength_reduction.rs new file mode 100644 index 0000000..ea69499 --- /dev/null +++ b/src/ir/passes/loops/strength_reduction.rs @@ -0,0 +1,192 @@ +use super::{scalar_evolution::LoopScevRecord, InductionOp, Lcssa, LoopSimplify, ScevAnalysis}; +use crate::{ + collections::linked_list::{LinkedListContainerPtr, LinkedListNodePtr}, + ir::{ + passes::{control_flow::CfgCanonicalize, loops::Scev, simple_dce::SimpleDce}, + passman::{GlobalPassMut, LocalPass, LocalPassMut, PassManager, PassResult, TransformPass}, + Block, + Context, + Func, + IBinaryOp, + Inst, + InstKind, + }, + utils::{ + def_use::{Usable, User}, + loop_info::{Loop, LoopWithDepth}, + }, +}; + +pub const LOOP_STRENGTH_REDUCTION: &str = "loop-strength-reduction"; + +#[derive(Default)] +pub struct LoopStrengthReduction { + /// Scalar evolution analysis. + scev: ScevAnalysis, +} + +impl LoopStrengthReduction { + fn process_loop(&self, ctx: &mut Context, lp: Loop, scevs: &LoopScevRecord) -> bool { + let preheader = lp.get_preheader(ctx, &self.scev.loops).unwrap(); + let header = lp.header(&self.scev.loops); + + let header_preds = header.preds(ctx); + assert_eq!(header_preds.len(), 2); // preheader and the backedge. + + let backedge = *header_preds + .iter() + .find(|&&pred| pred != preheader) + .unwrap(); + + let mut changed = false; + + let block_params = header.params(ctx).to_vec(); + + for block_param in block_params { + let Scev { init, step, op, .. } = if let Some(scev) = scevs.scevs.get(&block_param) { + scev + } else { + continue; + }; + + let ibinary_op = match op { + InductionOp::Add => IBinaryOp::Add, + InductionOp::Sub => IBinaryOp::Sub, + InductionOp::Mul | InductionOp::SDiv | InductionOp::Shl => return false, + }; + + for user in block_param.users(ctx) { + if !self + .scev + .dominance + .dominates(user.container(ctx).unwrap(), backedge) + { + // user must dominate the backedge block, so the induction is correct. + continue; + } + + if let InstKind::IBinary(IBinaryOp::Mul) = user.kind(ctx) { + let lhs = user.operand(ctx, 0); + let rhs = user.operand(ctx, 1); + + let k = if lhs == block_param { + rhs + } else if rhs == block_param { + lhs + } else { + unreachable!() + }; + + let k_def_block = k.def_block(ctx); + if self.scev.loops.is_in_loop(k_def_block, lp) { + // k is not a loop invariant + continue; + } + + // create init * k in the preheader + let mul = Inst::ibinary(ctx, IBinaryOp::Mul, *init, k); + preheader.push_inst_before_terminator(ctx, mul); + + // create new block parameter for the loop + let new_block_param = header.new_param(ctx, lhs.ty(ctx)); + + // also need step * k as the new step for the mul + let new_step = Inst::ibinary(ctx, IBinaryOp::Mul, *step, k); + preheader.push_inst_before_terminator(ctx, new_step); + + // inside the loop, use the new block parameter and addition + let induction_inst = + Inst::ibinary(ctx, ibinary_op, new_block_param, new_step.result(ctx, 0)); + user.insert_after(ctx, induction_inst); + + let old = user.result(ctx, 0); + for user in old.users(ctx) { + // use before add + user.replace(ctx, old, new_block_param); + } + + let preds = header.preds(ctx); + // single back-edge is guaranteed by loop-simplify, so there should be only + // two preds, preheader and backedge. + assert_eq!(preds.len(), 2); + + for pred in preds { + let tail = pred.tail(ctx).unwrap(); + if pred == preheader { + tail.add_succ_arg(ctx, header, new_block_param, mul.result(ctx, 0)); + } else { + tail.add_succ_arg( + ctx, + header, + new_block_param, + induction_inst.result(ctx, 0), + ); + } + } + + changed = true; + } + } + } + + changed + } +} + +impl LocalPassMut for LoopStrengthReduction { + type Output = (); + + fn run(&mut self, ctx: &mut Context, func: Func) -> PassResult<(Self::Output, bool)> { + let mut changed = false; + + let scevs = LocalPass::run(&mut self.scev, ctx, func)?; + + let mut loops = Vec::new(); + + for lp in self.scev.loops.loops() { + let depth = lp.depth(&self.scev.loops); + loops.push(LoopWithDepth { lp, depth }); + } + + loops.sort(); + + for LoopWithDepth { lp, .. } in loops { + changed |= self.process_loop(ctx, lp, &scevs); + } + + Ok(((), changed)) + } +} +impl GlobalPassMut for LoopStrengthReduction { + type Output = (); + + fn run(&mut self, ctx: &mut Context) -> PassResult<(Self::Output, bool)> { + let mut changed = false; + for func in ctx.funcs() { + let (_, local_changed) = LocalPassMut::run(self, ctx, func).unwrap(); + changed |= local_changed; + } + Ok(((), changed)) + } +} + +impl TransformPass for LoopStrengthReduction { + fn register(passman: &mut PassManager) + where + Self: Sized, + { + let pass = Self::default(); + + passman.register_transform( + LOOP_STRENGTH_REDUCTION, + pass, + vec![ + Box::new(CfgCanonicalize), + Box::new(LoopSimplify::default()), + Box::new(Lcssa::default()), + ], + ); + + passman.add_post_dep(LOOP_STRENGTH_REDUCTION, Box::new(SimpleDce::default())); + } +} diff --git a/src/ir/passes/loops/unroll.rs b/src/ir/passes/loops/unroll.rs index a0b1b27..6a107a1 100644 --- a/src/ir/passes/loops/unroll.rs +++ b/src/ir/passes/loops/unroll.rs @@ -1,10 +1,19 @@ -use super::{scalar_evolution::LoopScevRecord, Lcssa, LoopSimplify, Scev, ScevAnalysis}; +use super::{ + scalar_evolution::{LoopBound, LoopScevRecord}, + Lcssa, + LoopSimplify, + Scev, + ScevAnalysis, +}; use crate::{ collections::linked_list::{LinkedListContainerPtr, LinkedListNodePtr}, ir::{ debug::CommentPos, deep_clone::DeepCloneMap, - passes::{control_flow::CfgCanonicalize, loops::InductionOp}, + passes::{ + control_flow::CfgCanonicalize, + loops::{scalar_evolution::LoopBoundCond, InductionOp}, + }, passman::{ GlobalPassMut, LocalPass, @@ -49,18 +58,25 @@ impl LoopUnroll { fn process_loop(&mut self, ctx: &mut Context, lp: Loop, scevs: &LoopScevRecord) -> bool { let mut changed = false; - let (lhs, cmp_cond, rhs) = if let Some(bound) = scevs.loop_bounds.get(&lp).unwrap() { + let LoopBound { + block_param, + cond: cmp_cond, + bound, + reversed, + .. + } = if let Some(bound) = scevs.loop_bounds.get(&lp).unwrap() { bound } else { return false; }; - match cmp_cond { - ICmpCond::Eq | ICmpCond::Ne => return false, - // TODO: Support unsigned bounds - ICmpCond::Ule | ICmpCond::Ult => return false, - ICmpCond::Sle | ICmpCond::Slt => {} - } + let Scev { + block_param: repr, + init: start, + step, + op, + modulus, + } = scevs.scevs.get(block_param).unwrap(); // trip count calculation (for sle and addition) // let trip count as k, init = a, upper = b @@ -71,106 +87,85 @@ impl LoopUnroll { // for slt, just minus the upper bound by 1 // iterate the indvars, check if we can unroll the loop - if let Some(Scev { - repr, - start, - step, - op, - modulus, - }) = scevs - .scevs - .get(&lp) - .unwrap() - .iter() - .find(|s| s.repr == *lhs || s.repr == *rhs) - { - let (bound, cmp_cond, reversed) = if repr == lhs { - (rhs, cmp_cond, false) - } else if repr == rhs { - (lhs, cmp_cond, true) - } else { - unreachable!() - }; - - println!("[ loop-unroll ] start: {:?}", start); - - let mut start_const = None; - if let Some(inst) = start.def_inst(ctx) { - if let InstKind::IConst(start) = inst.kind(ctx) { - start_const = Some(start.as_signed()); - } + + println!("[ loop-unroll ] start: {:?}", start); + + let mut start_const = None; + if let Some(inst) = start.def_inst(ctx) { + if let InstKind::IConst(start) = inst.kind(ctx) { + start_const = Some(start.as_signed()); } + } - let mut step_const = None; - if let Some(inst) = step.def_inst(ctx) { - if let InstKind::IConst(step) = inst.kind(ctx) { - step_const = Some(step.as_signed()); - } + let mut step_const = None; + if let Some(inst) = step.def_inst(ctx) { + if let InstKind::IConst(step) = inst.kind(ctx) { + step_const = Some(step.as_signed()); } + } + + let mut bound_const = None; + if let Some(inst) = bound.def_inst(ctx) { + if let InstKind::IConst(bound) = inst.kind(ctx) { + bound_const = Some(bound.as_signed()); + } + } - let mut bound_const = None; - if let Some(inst) = bound.def_inst(ctx) { - if let InstKind::IConst(bound) = inst.kind(ctx) { - bound_const = Some(bound.as_signed()); + // two layers of option, outer means if there is a modulus, inner means if the + // modulus is a constant + let modulus_const = modulus.map(|v| { + let mut modulus_const = None; + if let Some(inst) = v.def_inst(ctx) { + if let InstKind::IConst(modulus) = inst.kind(ctx) { + modulus_const = Some(modulus.as_signed()); } } + modulus_const + }); - // two layers of option, outer means if there is a modulus, inner means if the - // modulus is a constant - let modulus_const = modulus.map(|v| { - let mut modulus_const = None; - if let Some(inst) = v.def_inst(ctx) { - if let InstKind::IConst(modulus) = inst.kind(ctx) { - modulus_const = Some(modulus.as_signed()); - } + let trip_count_const = if let (Some(start), Some(step), Some(bound), None) = + (start_const, step_const, bound_const, modulus_const) + { + match (op, cmp_cond, reversed) { + (InductionOp::Add, LoopBoundCond::Sle, false) => { + let trip_count = (bound - start) / step + 1; + Some(trip_count as usize) } - modulus_const - }); - - let trip_count_const = if let (Some(start), Some(step), Some(bound), None) = - (start_const, step_const, bound_const, modulus_const) - { - match (op, cmp_cond, reversed) { - (InductionOp::Add, ICmpCond::Sle, false) => { - let trip_count = (bound - start) / step + 1; - Some(trip_count as usize) - } - (InductionOp::Add, ICmpCond::Slt, false) => { - let trip_count = (bound - 1 - start) / step + 1; - Some(trip_count as usize) - } - // TODO: support more - _ => None, + (InductionOp::Add, LoopBoundCond::Slt, false) => { + let trip_count = (bound - 1 - start) / step + 1; + Some(trip_count as usize) } - } else { - None - }; - - println!("[ loop-unroll (const) ] start: {:?}", start_const); - println!("[ loop-unroll (const) ] step: {:?}", step_const); - println!("[ loop-unroll (const) ] bound: {:?}", bound_const); - println!("[ loop-unroll (const) ] trip count: {:?}", trip_count_const); - - if let Some(_trip_count) = trip_count_const { - changed |= self.unroll_const( - ctx, - lp, - *repr, - start_const.unwrap(), - step_const.unwrap(), - bound_const.unwrap(), - trip_count_const.unwrap(), - ); - } else { - // dynamic unrolling - #[allow(clippy::single_match)] - match (op, cmp_cond, reversed) { - (InductionOp::Add, ICmpCond::Sle | ICmpCond::Slt, false) => { - changed |= - self.unroll_dynamic(ctx, lp, *repr, *start, *step, *bound, *cmp_cond); - } - _ => {} + // TODO: support more + _ => None, + } + } else { + None + }; + + println!("[ loop-unroll (const) ] start: {:?}", start_const); + println!("[ loop-unroll (const) ] step: {:?}", step_const); + println!("[ loop-unroll (const) ] bound: {:?}", bound_const); + println!("[ loop-unroll (const) ] trip count: {:?}", trip_count_const); + + if let Some(_trip_count) = trip_count_const { + changed |= self.unroll_const( + ctx, + lp, + *repr, + start_const.unwrap(), + step_const.unwrap(), + bound_const.unwrap(), + trip_count_const.unwrap(), + ); + } else { + // dynamic unrolling + #[allow(clippy::single_match)] + match (op, cmp_cond, reversed) { + (InductionOp::Add, LoopBoundCond::Sle | LoopBoundCond::Slt, false) => { + changed |= + self.unroll_dynamic(ctx, lp, *repr, *start, *step, *bound, *cmp_cond); } + _ => {} } } @@ -186,27 +181,27 @@ impl LoopUnroll { start: Value, step: Value, bound: Value, - cmp_cond: ICmpCond, + cmp_cond: LoopBoundCond, ) -> bool { let bound_def_block = bound.def_block(ctx); - if self.scev.loop_ctx.is_in_loop(bound_def_block, lp) { + if self.scev.loops.is_in_loop(bound_def_block, lp) { // not loop invariant, cannot unroll return false; } let step_def_block = step.def_block(ctx); - if self.scev.loop_ctx.is_in_loop(step_def_block, lp) { + if self.scev.loops.is_in_loop(step_def_block, lp) { // not loop invariant, cannot unroll return false; } - let header = lp.header(&self.scev.loop_ctx); + let header = lp.header(&self.scev.loops); let mut is_pre_exit = false; - let exits = lp.get_exit_blocks(ctx, &self.scev.loop_ctx); + let exits = lp.get_exit_blocks(ctx, &self.scev.loops); for succ in header.succs(ctx) { - if !self.scev.loop_ctx.is_in_loop(succ, lp) { + if !self.scev.loops.is_in_loop(succ, lp) { is_pre_exit = true; break; } @@ -240,8 +235,8 @@ impl LoopUnroll { // } // ... old loop body with new i - let preheader = lp.get_preheader(ctx, &self.scev.loop_ctx).unwrap(); - let blocks = lp.get_blocks(ctx, &self.scev.loop_ctx); + let preheader = lp.get_preheader(ctx, &self.scev.loops).unwrap(); + let blocks = lp.get_blocks(ctx, &self.scev.loops); let insn: usize = blocks.iter().map(|bb| bb.insn(ctx)).sum(); if insn * self.unroll_factor > 512 { @@ -252,15 +247,15 @@ impl LoopUnroll { // calculate the trip count and unroll count let bound = match cmp_cond { - ICmpCond::Slt => { + LoopBoundCond::Slt => { let iconst = Inst::iconst(ctx, 1, ty); let sub = Inst::ibinary(ctx, IBinaryOp::Sub, bound, iconst.result(ctx, 0)); preheader.push_inst_before_terminator(ctx, iconst); preheader.push_inst_before_terminator(ctx, sub); sub.result(ctx, 0) } - ICmpCond::Sle => bound, - ICmpCond::Eq | ICmpCond::Ne | ICmpCond::Ult | ICmpCond::Ule => unimplemented!(), + LoopBoundCond::Sle => bound, + LoopBoundCond::Sgt | LoopBoundCond::Sge => unimplemented!("not supported"), }; // calculate the trip count let sub = Inst::ibinary(ctx, IBinaryOp::Sub, bound, start); @@ -428,12 +423,12 @@ impl LoopUnroll { self.deep_clone_map.clear(); - let header = lp.header(&self.scev.loop_ctx); + let header = lp.header(&self.scev.loops); let mut curr_header = header; - let preheader = lp.get_preheader(ctx, &self.scev.loop_ctx).unwrap(); + let preheader = lp.get_preheader(ctx, &self.scev.loops).unwrap(); - let blocks = lp.get_blocks(ctx, &self.scev.loop_ctx); + let blocks = lp.get_blocks(ctx, &self.scev.loops); let insn: usize = blocks.iter().map(|bb| bb.insn(ctx)).sum(); @@ -558,10 +553,13 @@ impl LocalPassMut for LoopUnroll { fn run(&mut self, ctx: &mut Context, func: Func) -> PassResult<(Self::Output, bool)> { let scevs = LocalPass::run(&mut self.scev, ctx, func)?; + ctx.alloc_all_names(); + println!("{}", scevs.display(ctx)); + let mut loops = Vec::new(); - for lp in self.scev.loop_ctx.loops() { - let depth = lp.depth(&self.scev.loop_ctx); + for lp in self.scev.loops.loops() { + let depth = lp.depth(&self.scev.loops); loops.push(LoopWithDepth { lp, depth }); } @@ -622,5 +620,7 @@ impl TransformPass for LoopUnroll { passman.add_parameter("unroll-factor", 4); passman.add_parameter("unroll-constant-all", true); + + passman.add_post_dep(LOOP_UNROLL, Box::new(CfgCanonicalize)); } } diff --git a/src/ir/passes/mod.rs b/src/ir/passes/mod.rs index a8e458c..7927374 100644 --- a/src/ir/passes/mod.rs +++ b/src/ir/passes/mod.rs @@ -14,6 +14,5 @@ pub mod instcombine; pub mod legalize; pub mod loops; pub mod mem2reg; -pub mod side_effect; pub mod simple_dce; pub mod tco; diff --git a/src/ir/passes/side_effect.rs b/src/ir/passes/side_effect.rs deleted file mode 100644 index 9fb73b7..0000000 --- a/src/ir/passes/side_effect.rs +++ /dev/null @@ -1,23 +0,0 @@ -//! Pass for Side Effect Analysis -//! -//! This pass can analyze the side effects of a function. It is useful when -//! running dead code elimination, global value numbering, etc. -//! -//! We consider a function has side effects if it has any of the following: -//! - It is the `main` function -//! - It writes to global slots or any unknown memory locations -//! - It calls other functions that have side effects -//! - It uses any indirect calls -//! -//! Note that all `decl`-ed functions are assumed to have side effects, unless -//! explicitly marked. -//! -//! There can be recursion (not in SysY, because it has no declaration of -//! functions, but in C) in the function call graph. We need to handle this case -//! carefully. The most straightforward way is to first find all SCCs in the -//! call graph, and then mark all functions in an SCC as having side effects if -//! any function in the SCC has side effects. -//! -//! TODO: Implement this pass - -pub struct SideEffectAnalysis; diff --git a/src/ir/passes/tco.rs b/src/ir/passes/tco.rs index b88c913..8025b0b 100644 --- a/src/ir/passes/tco.rs +++ b/src/ir/passes/tco.rs @@ -1,4 +1,4 @@ -use super::control_flow::CfgSimplify; +use super::control_flow::{CfgCanonicalize, CfgSimplify}; use crate::{ collections::linked_list::{LinkedListContainerPtr, LinkedListNodePtr}, ir::{ @@ -117,5 +117,6 @@ impl GlobalPassMut for Tco { impl TransformPass for Tco { fn register(passman: &mut PassManager) { passman.register_transform(TCO, Tco, vec![Box::new(CfgSimplify)]); + passman.add_post_dep(TCO, Box::new(CfgCanonicalize)); } } diff --git a/src/ir/passman.rs b/src/ir/passman.rs index 00f58fe..88f1f4b 100644 --- a/src/ir/passman.rs +++ b/src/ir/passman.rs @@ -156,6 +156,7 @@ pub struct PassManager { parameters: ParamStorage, transforms: FxHashMap>, deps: FxHashMap>>, + post_deps: FxHashMap>>, } #[derive(Default)] @@ -189,6 +190,11 @@ impl PassManager { self.deps.insert(name, deps); } + pub fn add_post_dep(&mut self, name: impl Into, dep: Box) { + let name = name.into(); + self.post_deps.entry(name).or_default().push(dep); + } + pub fn gather_transform_names(&self) -> Vec { let mut names: Vec = self .transforms @@ -215,29 +221,34 @@ impl PassManager { ctx: &mut Context, max_iter: usize, ) -> usize { - let mut iter = 0; let name = name.into(); - for _ in 0..max_iter { - iter += 1; - let mut changed = false; - let deps = &mut self.deps; - let transforms = &mut self.transforms; - let params = &self.parameters; + let deps = &mut self.deps; + let transforms = &mut self.transforms; + let params = &self.parameters; - for pass in deps.get_mut(&name).unwrap() { - GlobalPassMut::fetch_params(pass.as_mut(), params); - let (_, local_changed) = GlobalPassMut::run(pass.as_mut(), ctx).unwrap(); - changed |= local_changed; + let mut iter = 0; + let transform = transforms.get_mut(&name).unwrap(); + GlobalPassMut::fetch_params(transform.as_mut(), params); + + self.post_deps.entry(name.clone()).or_default(); + + while iter < max_iter { + iter += 1; + for dep in deps.get_mut(&name).unwrap() { + GlobalPassMut::fetch_params(dep.as_mut(), params); + let _ = GlobalPassMut::run(dep.as_mut(), ctx).unwrap(); } - let transform = transforms.get_mut(&name).unwrap(); - GlobalPassMut::fetch_params(transform.as_mut(), params); let (_, local_changed) = GlobalPassMut::run(transform.as_mut(), ctx).unwrap(); - changed |= local_changed; - if !changed { + if !local_changed { break; } + for post_dep in self.post_deps.get_mut(&name).unwrap() { + GlobalPassMut::fetch_params(post_dep.as_mut(), params); + let _ = GlobalPassMut::run(post_dep.as_mut(), ctx).unwrap(); + } } + iter } diff --git a/tests/test_ir_fold.rs b/tests/test_ir_fold.rs index e315f7e..7a9530f 100644 --- a/tests/test_ir_fold.rs +++ b/tests/test_ir_fold.rs @@ -64,4 +64,4 @@ fn test_ir_fold_float() { panic!("test failed: {:?}", e); } } -} \ No newline at end of file +} diff --git a/tests/test_ir_instcombine.rs b/tests/test_ir_instcombine.rs index 5e568bf..3db1fc7 100644 --- a/tests/test_ir_instcombine.rs +++ b/tests/test_ir_instcombine.rs @@ -2,7 +2,7 @@ use orzcc::{ collections::diagnostic::RenderOptions, frontend::ir::{into_ir, Parser}, ir::{ - passes::instcombine::{InstCombine, INSTCOMBINE}, + passes::instcombine::{Instcombine, INSTCOMBINE}, passman::{PassManager, TransformPass}, }, }; @@ -28,7 +28,7 @@ fn test_ir_instcombine_mul_to_shl() { let mut passman = PassManager::default(); - InstCombine::register(&mut passman); + Instcombine::register(&mut passman); assert_eq!(passman.run_transform(INSTCOMBINE, &mut ctx, 1), 1); @@ -57,7 +57,7 @@ fn test_ir_instcombine_mv_const_rhs() { let mut passman = PassManager::default(); - InstCombine::register(&mut passman); + Instcombine::register(&mut passman); assert_eq!(passman.run_transform(INSTCOMBINE, &mut ctx, 1), 1); @@ -86,7 +86,7 @@ fn test_ir_instcombine_add_zero_elim() { let mut passman = PassManager::default(); - InstCombine::register(&mut passman); + Instcombine::register(&mut passman); assert_eq!(passman.run_transform(INSTCOMBINE, &mut ctx, 3), 2);