forked from rust-lang/rust
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
When HIR auto-refs a comparison operator, clean it up by dereffing in…
… MIR Today, if you're comparing `&&T`s, it ends up auto-reffing in HIR. So the MIR ends up calling `PartialEq/Cmp` with `&&&T`, and the MIR inliner can only get that down to `&T`: <https://rust.godbolt.org/z/hje6jd4Yf>. So this adds an always-run pass to look at `Call`s in MIR with `from_hir_call: false` to just call the correct `Partial{Eq,Cmp}` implementation directly, even if it's debug and we're not running the inliner, to avoid needing to ever monomorphize a bunch of useless forwarding impls. This hopes to avoid ever needing something like rust-lang#108372 where we'd tell people to manually dereference the sides of their comparisons.
- Loading branch information
Showing
11 changed files
with
910 additions
and
213 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
86 changes: 86 additions & 0 deletions
86
compiler/rustc_mir_transform/src/simplify_ref_comparisons.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
use crate::MirPass; | ||
use rustc_middle::mir::*; | ||
use rustc_middle::ty::{self, Ty, TyCtxt}; | ||
|
||
/// This pass replaces `x OP y` with `*x OP *y` when `OP` is a comparison operator. | ||
/// | ||
/// The goal is to make is so that it's never better for the user to write | ||
/// `***x == ***y` than to write the obvious `x == y` (when `x` and `y` are | ||
/// references and thus those do the same thing). This is particularly | ||
/// important because the type-checker will auto-ref any comparison that's not | ||
/// done directly on a primitive. That means that `a_ref == b_ref` doesn't | ||
/// become `PartialEq::eq(a_ref, b_ref)`, even though that would work, but rather | ||
/// ```no_run | ||
/// # fn foo(a_ref: &i32, b_ref: &i32) -> bool { | ||
/// let temp1 = &a_ref; | ||
/// let temp2 = &b_ref; | ||
/// PartialEq::eq(temp1, temp2) | ||
/// # } | ||
/// ``` | ||
/// Thus this pass means it directly calls the *interesting* `impl` directly, | ||
/// rather than needing to monomorphize and/or inline it later. (And when this | ||
/// comment was written in March 2023, the MIR inliner seemed to only inline | ||
/// one level of `==`, so if the comparison is on something like `&&i32` the | ||
/// extra forwarding impls needed to be monomorphized even in an optimized build.) | ||
/// | ||
/// Make sure this runs before the `Derefer`, since it might add multiple levels | ||
/// of dereferences in the `Operand`s that are arguments to the `Call`. | ||
pub struct SimplifyRefComparisons; | ||
|
||
impl<'tcx> MirPass<'tcx> for SimplifyRefComparisons { | ||
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | ||
// Despite the method name, this is `PartialEq`, not `Eq`. | ||
let Some(partial_eq) = tcx.lang_items().eq_trait() else { return }; | ||
let Some(partial_ord) = tcx.lang_items().partial_ord_trait() else { return }; | ||
|
||
for block in body.basic_blocks.as_mut() { | ||
let terminator = block.terminator.as_mut().unwrap(); | ||
let TerminatorKind::Call { func, args, from_hir_call: false, .. } = | ||
&mut terminator.kind | ||
else { continue }; | ||
|
||
// Quickly skip unary operators | ||
if args.len() != 2 { | ||
continue; | ||
} | ||
let (Some(left_place), Some(right_place)) = (args[0].place(), args[1].place()) | ||
else { continue }; | ||
|
||
let (fn_def, fn_substs, fn_span) = | ||
func.const_fn_def().expect("HIR operators to always call the traits directly"); | ||
let substs = | ||
fn_substs.try_as_type_list().expect("HIR operators only have type parameters"); | ||
let [left_ty, right_ty] = *substs.as_slice() else { continue }; | ||
let (depth, new_left_ty, new_right_ty) = find_ref_depth(left_ty, right_ty); | ||
if depth == 0 { | ||
// Already dereffed as far as possible. | ||
continue; | ||
} | ||
|
||
// Check it's a comparison, not `+`/`&`/etc. | ||
let trait_def = tcx.trait_of_item(fn_def); | ||
if trait_def != Some(partial_eq) && trait_def != Some(partial_ord) { | ||
continue; | ||
} | ||
|
||
let derefs = vec![ProjectionElem::Deref; depth]; | ||
let new_substs = [new_left_ty.into(), new_right_ty.into()]; | ||
|
||
*func = Operand::function_handle(tcx, fn_def, new_substs, fn_span); | ||
args[0] = Operand::Copy(left_place.project_deeper(&derefs, tcx)); | ||
args[1] = Operand::Copy(right_place.project_deeper(&derefs, tcx)); | ||
} | ||
} | ||
} | ||
|
||
fn find_ref_depth<'tcx>(mut left: Ty<'tcx>, mut right: Ty<'tcx>) -> (usize, Ty<'tcx>, Ty<'tcx>) { | ||
let mut depth = 0; | ||
while let (ty::Ref(_, new_left, Mutability::Not), ty::Ref(_, new_right, Mutability::Not)) = | ||
(left.kind(), right.kind()) | ||
{ | ||
depth += 1; | ||
(left, right) = (*new_left, *new_right); | ||
} | ||
|
||
(depth, left, right) | ||
} |
168 changes: 168 additions & 0 deletions
168
tests/mir-opt/simplify_cmp.multi_ref_prim.SimplifyRefComparisons.diff
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
- // MIR for `multi_ref_prim` before SimplifyRefComparisons | ||
+ // MIR for `multi_ref_prim` after SimplifyRefComparisons | ||
|
||
fn multi_ref_prim(_1: &&&i32, _2: &&&i32) -> () { | ||
debug x => _1; // in scope 0 at $DIR/simplify_cmp.rs:+0:23: +0:24 | ||
debug y => _2; // in scope 0 at $DIR/simplify_cmp.rs:+0:34: +0:35 | ||
let mut _0: (); // return place in scope 0 at $DIR/simplify_cmp.rs:+0:45: +0:45 | ||
let _3: bool; // in scope 0 at $DIR/simplify_cmp.rs:+1:9: +1:11 | ||
let mut _4: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15 | ||
let mut _5: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20 | ||
let mut _7: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+2:14: +2:15 | ||
let mut _8: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+2:19: +2:20 | ||
let mut _10: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:14: +3:15 | ||
let mut _11: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:18: +3:19 | ||
let _12: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:18: +3:19 | ||
let mut _14: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:14: +4:15 | ||
let mut _15: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:19: +4:20 | ||
let _16: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:19: +4:20 | ||
let mut _18: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:14: +5:15 | ||
let mut _19: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:18: +5:19 | ||
let _20: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:18: +5:19 | ||
let mut _22: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:14: +6:15 | ||
let mut _23: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:19: +6:20 | ||
let _24: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:19: +6:20 | ||
scope 1 { | ||
debug _a => _3; // in scope 1 at $DIR/simplify_cmp.rs:+1:9: +1:11 | ||
let _6: bool; // in scope 1 at $DIR/simplify_cmp.rs:+2:9: +2:11 | ||
scope 2 { | ||
debug _b => _6; // in scope 2 at $DIR/simplify_cmp.rs:+2:9: +2:11 | ||
let _9: bool; // in scope 2 at $DIR/simplify_cmp.rs:+3:9: +3:11 | ||
scope 3 { | ||
debug _c => _9; // in scope 3 at $DIR/simplify_cmp.rs:+3:9: +3:11 | ||
let _13: bool; // in scope 3 at $DIR/simplify_cmp.rs:+4:9: +4:11 | ||
scope 4 { | ||
debug _d => _13; // in scope 4 at $DIR/simplify_cmp.rs:+4:9: +4:11 | ||
let _17: bool; // in scope 4 at $DIR/simplify_cmp.rs:+5:9: +5:11 | ||
scope 5 { | ||
debug _e => _17; // in scope 5 at $DIR/simplify_cmp.rs:+5:9: +5:11 | ||
let _21: bool; // in scope 5 at $DIR/simplify_cmp.rs:+6:9: +6:11 | ||
scope 6 { | ||
debug _f => _21; // in scope 6 at $DIR/simplify_cmp.rs:+6:9: +6:11 | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
bb0: { | ||
StorageLive(_3); // scope 0 at $DIR/simplify_cmp.rs:+1:9: +1:11 | ||
StorageLive(_4); // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15 | ||
_4 = &_1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15 | ||
StorageLive(_5); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20 | ||
_5 = &_2; // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20 | ||
- _3 = <&&&i32 as PartialEq>::eq(move _4, move _5) -> bb1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:20 | ||
+ _3 = <i32 as PartialEq>::eq((*(*(*_4))), (*(*(*_5)))) -> bb1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:20 | ||
// mir::Constant | ||
// + span: $DIR/simplify_cmp.rs:18:14: 18:20 | ||
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialEq>::eq}, val: Value(<ZST>) } | ||
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialEq>::eq}, val: Value(<ZST>) } | ||
} | ||
|
||
bb1: { | ||
StorageDead(_5); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20 | ||
StorageDead(_4); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20 | ||
StorageLive(_6); // scope 1 at $DIR/simplify_cmp.rs:+2:9: +2:11 | ||
StorageLive(_7); // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:15 | ||
_7 = &_1; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:15 | ||
StorageLive(_8); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20 | ||
_8 = &_2; // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20 | ||
- _6 = <&&&i32 as PartialEq>::ne(move _7, move _8) -> bb2; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:20 | ||
+ _6 = <i32 as PartialEq>::ne((*(*(*_7))), (*(*(*_8)))) -> bb2; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:20 | ||
// mir::Constant | ||
// + span: $DIR/simplify_cmp.rs:19:14: 19:20 | ||
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialEq>::ne}, val: Value(<ZST>) } | ||
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialEq>::ne}, val: Value(<ZST>) } | ||
} | ||
|
||
bb2: { | ||
StorageDead(_8); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20 | ||
StorageDead(_7); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20 | ||
StorageLive(_9); // scope 2 at $DIR/simplify_cmp.rs:+3:9: +3:11 | ||
StorageLive(_10); // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:15 | ||
_10 = &_1; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:15 | ||
StorageLive(_11); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19 | ||
StorageLive(_12); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19 | ||
_12 = &(*_2); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19 | ||
_11 = &_12; // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19 | ||
- _9 = <&&&i32 as PartialOrd>::lt(move _10, move _11) -> bb3; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:19 | ||
+ _9 = <i32 as PartialOrd>::lt((*(*(*_10))), (*(*(*_11)))) -> bb3; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:19 | ||
// mir::Constant | ||
// + span: $DIR/simplify_cmp.rs:20:14: 20:19 | ||
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::lt}, val: Value(<ZST>) } | ||
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::lt}, val: Value(<ZST>) } | ||
} | ||
|
||
bb3: { | ||
StorageDead(_11); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19 | ||
StorageDead(_10); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19 | ||
StorageDead(_12); // scope 2 at $DIR/simplify_cmp.rs:+3:19: +3:20 | ||
StorageLive(_13); // scope 3 at $DIR/simplify_cmp.rs:+4:9: +4:11 | ||
StorageLive(_14); // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:15 | ||
_14 = &_1; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:15 | ||
StorageLive(_15); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20 | ||
StorageLive(_16); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20 | ||
_16 = &(*_2); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20 | ||
_15 = &_16; // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20 | ||
- _13 = <&&&i32 as PartialOrd>::le(move _14, move _15) -> bb4; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:20 | ||
+ _13 = <i32 as PartialOrd>::le((*(*(*_14))), (*(*(*_15)))) -> bb4; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:20 | ||
// mir::Constant | ||
// + span: $DIR/simplify_cmp.rs:21:14: 21:20 | ||
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::le}, val: Value(<ZST>) } | ||
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::le}, val: Value(<ZST>) } | ||
} | ||
|
||
bb4: { | ||
StorageDead(_15); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20 | ||
StorageDead(_14); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20 | ||
StorageDead(_16); // scope 3 at $DIR/simplify_cmp.rs:+4:20: +4:21 | ||
StorageLive(_17); // scope 4 at $DIR/simplify_cmp.rs:+5:9: +5:11 | ||
StorageLive(_18); // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:15 | ||
_18 = &_1; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:15 | ||
StorageLive(_19); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19 | ||
StorageLive(_20); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19 | ||
_20 = &(*_2); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19 | ||
_19 = &_20; // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19 | ||
- _17 = <&&&i32 as PartialOrd>::gt(move _18, move _19) -> bb5; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:19 | ||
+ _17 = <i32 as PartialOrd>::gt((*(*(*_18))), (*(*(*_19)))) -> bb5; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:19 | ||
// mir::Constant | ||
// + span: $DIR/simplify_cmp.rs:22:14: 22:19 | ||
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::gt}, val: Value(<ZST>) } | ||
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::gt}, val: Value(<ZST>) } | ||
} | ||
|
||
bb5: { | ||
StorageDead(_19); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19 | ||
StorageDead(_18); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19 | ||
StorageDead(_20); // scope 4 at $DIR/simplify_cmp.rs:+5:19: +5:20 | ||
StorageLive(_21); // scope 5 at $DIR/simplify_cmp.rs:+6:9: +6:11 | ||
StorageLive(_22); // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:15 | ||
_22 = &_1; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:15 | ||
StorageLive(_23); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20 | ||
StorageLive(_24); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20 | ||
_24 = &(*_2); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20 | ||
_23 = &_24; // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20 | ||
- _21 = <&&&i32 as PartialOrd>::ge(move _22, move _23) -> bb6; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:20 | ||
+ _21 = <i32 as PartialOrd>::ge((*(*(*_22))), (*(*(*_23)))) -> bb6; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:20 | ||
// mir::Constant | ||
// + span: $DIR/simplify_cmp.rs:23:14: 23:20 | ||
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::ge}, val: Value(<ZST>) } | ||
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::ge}, val: Value(<ZST>) } | ||
} | ||
|
||
bb6: { | ||
StorageDead(_23); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20 | ||
StorageDead(_22); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20 | ||
StorageDead(_24); // scope 5 at $DIR/simplify_cmp.rs:+6:20: +6:21 | ||
_0 = const (); // scope 0 at $DIR/simplify_cmp.rs:+0:45: +7:2 | ||
StorageDead(_21); // scope 5 at $DIR/simplify_cmp.rs:+7:1: +7:2 | ||
StorageDead(_17); // scope 4 at $DIR/simplify_cmp.rs:+7:1: +7:2 | ||
StorageDead(_13); // scope 3 at $DIR/simplify_cmp.rs:+7:1: +7:2 | ||
StorageDead(_9); // scope 2 at $DIR/simplify_cmp.rs:+7:1: +7:2 | ||
StorageDead(_6); // scope 1 at $DIR/simplify_cmp.rs:+7:1: +7:2 | ||
StorageDead(_3); // scope 0 at $DIR/simplify_cmp.rs:+7:1: +7:2 | ||
return; // scope 0 at $DIR/simplify_cmp.rs:+7:2: +7:2 | ||
} | ||
} | ||
|
Oops, something went wrong.