Skip to content

Commit

Permalink
When HIR auto-refs a comparison operator, clean it up by dereffing in…
Browse files Browse the repository at this point in the history
… MIR

Today, if you're comparing `&&T`s, it ends up auto-reffing in HIR.  So the MIR ends up calling `PartialEq/Cmp` with `&&&T`, and the MIR inliner can only get that down to `&T`: <https://rust.godbolt.org/z/hje6jd4Yf>.

So this adds an always-run pass to look at `Call`s in MIR with `from_hir_call: false` to just call the correct `Partial{Eq,Cmp}` implementation directly, even if it's debug and we're not running the inliner, to avoid needing to ever monomorphize a bunch of useless forwarding impls.

This hopes to avoid ever needing something like rust-lang#108372 where we'd tell people to manually dereference the sides of their comparisons.
  • Loading branch information
scottmcm committed Mar 18, 2023
1 parent e4b9f86 commit 7b74759
Show file tree
Hide file tree
Showing 11 changed files with 910 additions and 213 deletions.
10 changes: 7 additions & 3 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1931,9 +1931,13 @@ impl<'tcx> Operand<'tcx> {
///
/// While this is unlikely in general, it's the normal case of what you'll
/// find as the `func` in a [`TerminatorKind::Call`].
pub fn const_fn_def(&self) -> Option<(DefId, SubstsRef<'tcx>)> {
let const_ty = self.constant()?.literal.ty();
if let ty::FnDef(def_id, substs) = *const_ty.kind() { Some((def_id, substs)) } else { None }
pub fn const_fn_def(&self) -> Option<(DefId, SubstsRef<'tcx>, Span)> {
let constant = self.constant()?;
if let ty::FnDef(def_id, substs) = *constant.literal.ty().kind() {
Some((def_id, substs, constant.span))
} else {
None
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/instcombine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl<'tcx> InstCombineContext<'tcx, '_> {
else { return };

// Only bother looking more if it's easy to know what we're calling
let Some((fn_def_id, fn_substs)) = func.const_fn_def()
let Some((fn_def_id, fn_substs, _span)) = func.const_fn_def()
else { return };

// Clone needs one subst, so we can cheaply rule out other stuff
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ mod ssa;
pub mod simplify;
mod simplify_branches;
mod simplify_comparison_integral;
mod simplify_ref_comparisons;
mod sroa;
mod uninhabited_enum_branching;
mod unreachable_prop;
Expand Down Expand Up @@ -497,6 +498,8 @@ fn run_analysis_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&cleanup_post_borrowck::CleanupPostBorrowck,
&remove_noop_landing_pads::RemoveNoopLandingPads,
&simplify::SimplifyCfg::new("early-opt"),
// Adds more `Deref`s, so needs to be before `Derefer`.
&simplify_ref_comparisons::SimplifyRefComparisons,
&deref_separator::Derefer,
];

Expand Down
86 changes: 86 additions & 0 deletions compiler/rustc_mir_transform/src/simplify_ref_comparisons.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use crate::MirPass;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, Ty, TyCtxt};

/// This pass replaces `x OP y` with `*x OP *y` when `OP` is a comparison operator.
///
/// The goal is to make is so that it's never better for the user to write
/// `***x == ***y` than to write the obvious `x == y` (when `x` and `y` are
/// references and thus those do the same thing). This is particularly
/// important because the type-checker will auto-ref any comparison that's not
/// done directly on a primitive. That means that `a_ref == b_ref` doesn't
/// become `PartialEq::eq(a_ref, b_ref)`, even though that would work, but rather
/// ```no_run
/// # fn foo(a_ref: &i32, b_ref: &i32) -> bool {
/// let temp1 = &a_ref;
/// let temp2 = &b_ref;
/// PartialEq::eq(temp1, temp2)
/// # }
/// ```
/// Thus this pass means it directly calls the *interesting* `impl` directly,
/// rather than needing to monomorphize and/or inline it later. (And when this
/// comment was written in March 2023, the MIR inliner seemed to only inline
/// one level of `==`, so if the comparison is on something like `&&i32` the
/// extra forwarding impls needed to be monomorphized even in an optimized build.)
///
/// Make sure this runs before the `Derefer`, since it might add multiple levels
/// of dereferences in the `Operand`s that are arguments to the `Call`.
pub struct SimplifyRefComparisons;

impl<'tcx> MirPass<'tcx> for SimplifyRefComparisons {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// Despite the method name, this is `PartialEq`, not `Eq`.
let Some(partial_eq) = tcx.lang_items().eq_trait() else { return };
let Some(partial_ord) = tcx.lang_items().partial_ord_trait() else { return };

for block in body.basic_blocks.as_mut() {
let terminator = block.terminator.as_mut().unwrap();
let TerminatorKind::Call { func, args, from_hir_call: false, .. } =
&mut terminator.kind
else { continue };

// Quickly skip unary operators
if args.len() != 2 {
continue;
}
let (Some(left_place), Some(right_place)) = (args[0].place(), args[1].place())
else { continue };

let (fn_def, fn_substs, fn_span) =
func.const_fn_def().expect("HIR operators to always call the traits directly");
let substs =
fn_substs.try_as_type_list().expect("HIR operators only have type parameters");
let [left_ty, right_ty] = *substs.as_slice() else { continue };
let (depth, new_left_ty, new_right_ty) = find_ref_depth(left_ty, right_ty);
if depth == 0 {
// Already dereffed as far as possible.
continue;
}

// Check it's a comparison, not `+`/`&`/etc.
let trait_def = tcx.trait_of_item(fn_def);
if trait_def != Some(partial_eq) && trait_def != Some(partial_ord) {
continue;
}

let derefs = vec![ProjectionElem::Deref; depth];
let new_substs = [new_left_ty.into(), new_right_ty.into()];

*func = Operand::function_handle(tcx, fn_def, new_substs, fn_span);
args[0] = Operand::Copy(left_place.project_deeper(&derefs, tcx));
args[1] = Operand::Copy(right_place.project_deeper(&derefs, tcx));
}
}
}

fn find_ref_depth<'tcx>(mut left: Ty<'tcx>, mut right: Ty<'tcx>) -> (usize, Ty<'tcx>, Ty<'tcx>) {
let mut depth = 0;
while let (ty::Ref(_, new_left, Mutability::Not), ty::Ref(_, new_right, Mutability::Not)) =
(left.kind(), right.kind())
{
depth += 1;
(left, right) = (*new_left, *new_right);
}

(depth, left, right)
}
168 changes: 168 additions & 0 deletions tests/mir-opt/simplify_cmp.multi_ref_prim.SimplifyRefComparisons.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
- // MIR for `multi_ref_prim` before SimplifyRefComparisons
+ // MIR for `multi_ref_prim` after SimplifyRefComparisons

fn multi_ref_prim(_1: &&&i32, _2: &&&i32) -> () {
debug x => _1; // in scope 0 at $DIR/simplify_cmp.rs:+0:23: +0:24
debug y => _2; // in scope 0 at $DIR/simplify_cmp.rs:+0:34: +0:35
let mut _0: (); // return place in scope 0 at $DIR/simplify_cmp.rs:+0:45: +0:45
let _3: bool; // in scope 0 at $DIR/simplify_cmp.rs:+1:9: +1:11
let mut _4: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15
let mut _5: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
let mut _7: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+2:14: +2:15
let mut _8: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+2:19: +2:20
let mut _10: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:14: +3:15
let mut _11: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:18: +3:19
let _12: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:18: +3:19
let mut _14: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:14: +4:15
let mut _15: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:19: +4:20
let _16: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:19: +4:20
let mut _18: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:14: +5:15
let mut _19: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:18: +5:19
let _20: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:18: +5:19
let mut _22: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:14: +6:15
let mut _23: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:19: +6:20
let _24: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:19: +6:20
scope 1 {
debug _a => _3; // in scope 1 at $DIR/simplify_cmp.rs:+1:9: +1:11
let _6: bool; // in scope 1 at $DIR/simplify_cmp.rs:+2:9: +2:11
scope 2 {
debug _b => _6; // in scope 2 at $DIR/simplify_cmp.rs:+2:9: +2:11
let _9: bool; // in scope 2 at $DIR/simplify_cmp.rs:+3:9: +3:11
scope 3 {
debug _c => _9; // in scope 3 at $DIR/simplify_cmp.rs:+3:9: +3:11
let _13: bool; // in scope 3 at $DIR/simplify_cmp.rs:+4:9: +4:11
scope 4 {
debug _d => _13; // in scope 4 at $DIR/simplify_cmp.rs:+4:9: +4:11
let _17: bool; // in scope 4 at $DIR/simplify_cmp.rs:+5:9: +5:11
scope 5 {
debug _e => _17; // in scope 5 at $DIR/simplify_cmp.rs:+5:9: +5:11
let _21: bool; // in scope 5 at $DIR/simplify_cmp.rs:+6:9: +6:11
scope 6 {
debug _f => _21; // in scope 6 at $DIR/simplify_cmp.rs:+6:9: +6:11
}
}
}
}
}
}

bb0: {
StorageLive(_3); // scope 0 at $DIR/simplify_cmp.rs:+1:9: +1:11
StorageLive(_4); // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15
_4 = &_1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15
StorageLive(_5); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
_5 = &_2; // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
- _3 = <&&&i32 as PartialEq>::eq(move _4, move _5) -> bb1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:20
+ _3 = <i32 as PartialEq>::eq((*(*(*_4))), (*(*(*_5)))) -> bb1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:20
// mir::Constant
// + span: $DIR/simplify_cmp.rs:18:14: 18:20
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialEq>::eq}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialEq>::eq}, val: Value(<ZST>) }
}

bb1: {
StorageDead(_5); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
StorageDead(_4); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
StorageLive(_6); // scope 1 at $DIR/simplify_cmp.rs:+2:9: +2:11
StorageLive(_7); // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:15
_7 = &_1; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:15
StorageLive(_8); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
_8 = &_2; // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
- _6 = <&&&i32 as PartialEq>::ne(move _7, move _8) -> bb2; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:20
+ _6 = <i32 as PartialEq>::ne((*(*(*_7))), (*(*(*_8)))) -> bb2; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:20
// mir::Constant
// + span: $DIR/simplify_cmp.rs:19:14: 19:20
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialEq>::ne}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialEq>::ne}, val: Value(<ZST>) }
}

bb2: {
StorageDead(_8); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
StorageDead(_7); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
StorageLive(_9); // scope 2 at $DIR/simplify_cmp.rs:+3:9: +3:11
StorageLive(_10); // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:15
_10 = &_1; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:15
StorageLive(_11); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
StorageLive(_12); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
_12 = &(*_2); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
_11 = &_12; // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
- _9 = <&&&i32 as PartialOrd>::lt(move _10, move _11) -> bb3; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:19
+ _9 = <i32 as PartialOrd>::lt((*(*(*_10))), (*(*(*_11)))) -> bb3; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:19
// mir::Constant
// + span: $DIR/simplify_cmp.rs:20:14: 20:19
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::lt}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::lt}, val: Value(<ZST>) }
}

bb3: {
StorageDead(_11); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
StorageDead(_10); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
StorageDead(_12); // scope 2 at $DIR/simplify_cmp.rs:+3:19: +3:20
StorageLive(_13); // scope 3 at $DIR/simplify_cmp.rs:+4:9: +4:11
StorageLive(_14); // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:15
_14 = &_1; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:15
StorageLive(_15); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
StorageLive(_16); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
_16 = &(*_2); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
_15 = &_16; // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
- _13 = <&&&i32 as PartialOrd>::le(move _14, move _15) -> bb4; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:20
+ _13 = <i32 as PartialOrd>::le((*(*(*_14))), (*(*(*_15)))) -> bb4; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:20
// mir::Constant
// + span: $DIR/simplify_cmp.rs:21:14: 21:20
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::le}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::le}, val: Value(<ZST>) }
}

bb4: {
StorageDead(_15); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
StorageDead(_14); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
StorageDead(_16); // scope 3 at $DIR/simplify_cmp.rs:+4:20: +4:21
StorageLive(_17); // scope 4 at $DIR/simplify_cmp.rs:+5:9: +5:11
StorageLive(_18); // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:15
_18 = &_1; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:15
StorageLive(_19); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
StorageLive(_20); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
_20 = &(*_2); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
_19 = &_20; // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
- _17 = <&&&i32 as PartialOrd>::gt(move _18, move _19) -> bb5; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:19
+ _17 = <i32 as PartialOrd>::gt((*(*(*_18))), (*(*(*_19)))) -> bb5; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:19
// mir::Constant
// + span: $DIR/simplify_cmp.rs:22:14: 22:19
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::gt}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::gt}, val: Value(<ZST>) }
}

bb5: {
StorageDead(_19); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
StorageDead(_18); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
StorageDead(_20); // scope 4 at $DIR/simplify_cmp.rs:+5:19: +5:20
StorageLive(_21); // scope 5 at $DIR/simplify_cmp.rs:+6:9: +6:11
StorageLive(_22); // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:15
_22 = &_1; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:15
StorageLive(_23); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
StorageLive(_24); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
_24 = &(*_2); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
_23 = &_24; // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
- _21 = <&&&i32 as PartialOrd>::ge(move _22, move _23) -> bb6; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:20
+ _21 = <i32 as PartialOrd>::ge((*(*(*_22))), (*(*(*_23)))) -> bb6; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:20
// mir::Constant
// + span: $DIR/simplify_cmp.rs:23:14: 23:20
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::ge}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::ge}, val: Value(<ZST>) }
}

bb6: {
StorageDead(_23); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
StorageDead(_22); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
StorageDead(_24); // scope 5 at $DIR/simplify_cmp.rs:+6:20: +6:21
_0 = const (); // scope 0 at $DIR/simplify_cmp.rs:+0:45: +7:2
StorageDead(_21); // scope 5 at $DIR/simplify_cmp.rs:+7:1: +7:2
StorageDead(_17); // scope 4 at $DIR/simplify_cmp.rs:+7:1: +7:2
StorageDead(_13); // scope 3 at $DIR/simplify_cmp.rs:+7:1: +7:2
StorageDead(_9); // scope 2 at $DIR/simplify_cmp.rs:+7:1: +7:2
StorageDead(_6); // scope 1 at $DIR/simplify_cmp.rs:+7:1: +7:2
StorageDead(_3); // scope 0 at $DIR/simplify_cmp.rs:+7:1: +7:2
return; // scope 0 at $DIR/simplify_cmp.rs:+7:2: +7:2
}
}

Loading

0 comments on commit 7b74759

Please sign in to comment.