diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 750531b638e4d..e52d6fc60eba6 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -114,6 +114,7 @@ macro_rules! arena_types { [decode] specialization_graph: rustc_middle::traits::specialization_graph::Graph, [] crate_inherent_impls: rustc_middle::ty::CrateInherentImpls, [] hir_owner_nodes: rustc_hir::OwnerNodes<'tcx>, + [] thir_pats: rustc_middle::thir::Pat<'tcx>, ]); ) } diff --git a/compiler/rustc_mir_build/src/builder/expr/as_place.rs b/compiler/rustc_mir_build/src/builder/expr/as_place.rs index 0086775e9f46d..f12b8a8cd15e4 100644 --- a/compiler/rustc_mir_build/src/builder/expr/as_place.rs +++ b/compiler/rustc_mir_build/src/builder/expr/as_place.rs @@ -301,6 +301,10 @@ impl<'tcx> PlaceBuilder<'tcx> { &self.projection } + pub(crate) fn projection_mut(&mut self) -> &mut [PlaceElem<'tcx>] { + &mut self.projection + } + pub(crate) fn field(self, f: FieldIdx, ty: Ty<'tcx>) -> Self { self.project(PlaceElem::Field(f, ty)) } diff --git a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs index 9d59ffc88ba23..5e72338afcb24 100644 --- a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs +++ b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs @@ -1,9 +1,14 @@ +use std::ops; + +use either::Either; +use rustc_middle::bug; use rustc_middle::mir::*; use rustc_middle::thir::{self, *}; use rustc_middle::ty::{self, Ty, TypeVisitableExt}; use crate::builder::Builder; use crate::builder::expr::as_place::{PlaceBase, PlaceBuilder}; +use crate::builder::matches::util::Range; use crate::builder::matches::{FlatPat, MatchPairTree, TestCase}; impl<'a, 'tcx> Builder<'a, 'tcx> { @@ -33,6 +38,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { /// Used internally by [`MatchPairTree::for_pattern`]. fn prefix_slice_suffix<'pat>( &mut self, + top_pattern: &'pat Pat<'tcx>, match_pairs: &mut Vec>, place: &PlaceBuilder<'tcx>, prefix: &'pat [Box>], @@ -54,11 +60,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ((prefix.len() + suffix.len()).try_into().unwrap(), false) }; - match_pairs.extend(prefix.iter().enumerate().map(|(idx, subpattern)| { - let elem = - ProjectionElem::ConstantIndex { offset: idx as u64, min_length, from_end: false }; - MatchPairTree::for_pattern(place.clone_project(elem), subpattern, self) - })); + if !prefix.is_empty() { + let bounds = Range::from_start(0..prefix.len() as u64); + let subpattern = bounds.apply(prefix); + self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length) + .for_each(|pair| match_pairs.push(pair)); + } if let Some(subslice_pat) = opt_slice { let suffix_len = suffix.len() as u64; @@ -70,16 +77,258 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { match_pairs.push(MatchPairTree::for_pattern(subslice, subslice_pat, self)); } - match_pairs.extend(suffix.iter().rev().enumerate().map(|(idx, subpattern)| { - let end_offset = (idx + 1) as u64; - let elem = ProjectionElem::ConstantIndex { - offset: if exact_size { min_length - end_offset } else { end_offset }, - min_length, - from_end: !exact_size, + if !suffix.is_empty() { + let bounds = Range::from_end(0..suffix.len() as u64); + let subpattern = bounds.apply(suffix); + self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length) + .for_each(|pair| match_pairs.push(pair)); + } + } + + // Traverses either side of a slice pattern (prefix/suffix) and yields an iterator of `MatchPairTree`s + // to cover all it's constant and non-constant subpatterns. + fn build_slice_branch<'pat, 'b>( + &'b mut self, + bounds: Range, + place: &'b PlaceBuilder<'tcx>, + top_pattern: &'pat Pat<'tcx>, + pattern: &'pat [Box>], + min_length: u64, + ) -> impl Iterator> + use<'a, 'tcx, 'pat, 'b> { + let entries = self.find_const_groups(pattern); + + entries.into_iter().map(move |entry| { + // Common case handler for both non-constant and constant subpatterns not in a range. + let mut build_single = |idx| { + let subpattern = &pattern[idx as usize]; + let place = place.clone_project(ProjectionElem::ConstantIndex { + offset: bounds.shift_idx(idx), + min_length: pattern.len() as u64, + from_end: bounds.from_end, + }); + + MatchPairTree::for_pattern(place, subpattern, self) }; - let place = place.clone_project(elem); - MatchPairTree::for_pattern(place, subpattern, self) - })); + + match entry { + Either::Right(range) if range.end - range.start > 1 => { + // Figure out which subslice of our already sliced pattern we're looking at. + let subpattern = &pattern[range.start as usize..range.end as usize]; + let elem_ty = subpattern[0].ty; + + // Right, we 've found a group of constant patterns worth grouping for later. + // We'll collect all the leaves we can find and create a single `ValTree` out of them. + let valtree = self.simplify_const_pattern_slice_into_valtree(subpattern); + self.valtree_to_match_pair( + top_pattern, + valtree, + place.clone(), + elem_ty, + bounds.shift_range(range), + min_length, + ) + } + Either::Right(range) => build_single(range.start), + Either::Left(idx) => build_single(idx), + } + }) + } + + // Given a partial view of the elements in a slice pattern, returns a list + // with left denoting non-constant element indices and right denoting ranges of constant elements. + fn find_const_groups(&self, pattern: &[Box>]) -> Vec>> { + let mut entries = Vec::new(); + let mut current_seq_start = None; + + for (idx, pat) in pattern.iter().enumerate() { + if self.is_constant_pattern(pat) { + if current_seq_start.is_none() { + current_seq_start = Some(idx as u64); + } else { + continue; + } + } else { + if let Some(start) = current_seq_start { + entries.push(Either::Right(start..idx as u64)); + current_seq_start = None; + } + entries.push(Either::Left(idx as u64)); + } + } + + if let Some(start) = current_seq_start { + entries.push(Either::Right(start..pattern.len() as u64)); + } + + entries + } + + // Checks if a pattern is constant and represented by a single scalar leaf. + fn is_constant_pattern(&self, pat: &Pat<'tcx>) -> bool { + if let PatKind::Constant { value } = pat.kind + && let Const::Ty(_, const_) = value + && let ty::ConstKind::Value(cv) = const_.kind() + && let ty::ValTree::Leaf(_) = cv.valtree + { + true + } else { + false + } + } + + // Extract the `ValTree` from a constant pattern. + // You must ensure that the pattern is a constant pattern before calling this function or it will panic. + fn extract_leaf(&self, pat: &Pat<'tcx>) -> ty::ValTree<'tcx> { + if let PatKind::Constant { value } = pat.kind + && let Const::Ty(_, const_) = value + && let ty::ConstKind::Value(cv) = const_.kind() + && matches!(cv.valtree, ty::ValTree::Leaf(_)) + { + cv.valtree + } else { + bug!("expected constant pattern, got {:?}", pat) + } + } + + // Simplifies a slice of constant patterns into a single flattened `ValTree`. + fn simplify_const_pattern_slice_into_valtree( + &self, + subslice: &[Box>], + ) -> ty::ValTree<'tcx> { + let leaves = subslice.iter().map(|p| self.extract_leaf(p)); + let interned = self.tcx.arena.alloc_from_iter(leaves); + ty::ValTree::Branch(interned) + } + + // Given a `ValTree` representing a slice of constant patterns, returns a `MatchPairTree` + // representing the slice pattern, providing as much info about subsequences in the slice as possible + // to later lowering stages. + fn valtree_to_match_pair<'pat>( + &mut self, + source_pattern: &'pat Pat<'tcx>, + valtree: ty::ValTree<'tcx>, + place: PlaceBuilder<'tcx>, + elem_ty: Ty<'tcx>, + range: Range, + min_length: u64, + ) -> MatchPairTree<'pat, 'tcx> { + let tcx = self.tcx; + let leaves = match valtree { + ty::ValTree::Leaf(_) => bug!("expected branch, got leaf"), + ty::ValTree::Branch(leaves) => leaves, + }; + + assert!(range.len() == leaves.len() as u64); + let mut subpairs = Vec::new(); + let mut were_merged = 0; + + if elem_ty == tcx.types.u8 { + let leaf_bits = |leaf: ty::ValTree<'tcx>| match leaf { + ty::ValTree::Leaf(scalar) => scalar.to_u8(), + _ => bug!("found unflatted valtree"), + }; + + let mut fuse_group = |first_idx, len| { + were_merged += len; + + let data = leaves[first_idx..first_idx + len] + .iter() + .rev() + .copied() + .map(leaf_bits) + .fold(0u32, |acc, x| (acc << 8) | u32::from(x)); + + let fused_ty = match len { + 2 => tcx.types.u16, + 3 | 4 => tcx.types.u32, + _ => unreachable!(), + }; + + let scalar = match len { + 2 => ty::ScalarInt::from(data as u16), + 3 | 4 => ty::ScalarInt::from(data), + _ => unreachable!(), + }; + + let valtree = ty::ValTree::Leaf(scalar); + let ty_const = + ty::Const::new(tcx, ty::ConstKind::Value(ty::Value { ty: fused_ty, valtree })); + + let value = Const::Ty(fused_ty, ty_const); + let test_case = TestCase::FusedConstant { value, fused: len as u64 }; + + let pattern = tcx.arena.alloc(Pat { + ty: fused_ty, + span: source_pattern.span, + kind: PatKind::Constant { value }, + }); + + let place = place + .clone_project(ProjectionElem::ConstantIndex { + offset: range.shift_idx(first_idx as u64), + min_length, + from_end: range.from_end, + }) + .to_place(self); + + subpairs.push(MatchPairTree { + place: Some(place), + test_case, + subpairs: Vec::new(), + pattern, + }); + }; + + let indices = |group_size, skip| { + (skip..usize::MAX) + .take_while(move |i| i * group_size + (group_size - 1) < leaves.len()) + }; + + let mut skip = 0; + for i in (2..=4).rev() { + for idx in indices(i, skip) { + fuse_group(idx * i, i); + skip += i; + } + } + } + + for (idx, leaf) in leaves.iter().enumerate().skip(were_merged) { + let ty_const = ty::Const::new( + tcx, + ty::ConstKind::Value(ty::Value { ty: elem_ty, valtree: *leaf }), + ); + let value = Const::Ty(elem_ty, ty_const); + let test_case = TestCase::Constant { value }; + + let pattern = tcx.arena.alloc(Pat { + ty: elem_ty, + span: source_pattern.span, + kind: PatKind::Constant { value }, + }); + + let place = place + .clone_project(ProjectionElem::ConstantIndex { + offset: range.start + idx as u64, + min_length, + from_end: range.from_end, + }) + .to_place(self); + + subpairs.push(MatchPairTree { + place: Some(place), + test_case, + subpairs: Vec::new(), + pattern, + }); + } + + MatchPairTree { + place: None, + test_case: TestCase::Irrefutable { binding: None, ascription: None }, + subpairs, + pattern: source_pattern, + } } } @@ -192,11 +441,25 @@ impl<'pat, 'tcx> MatchPairTree<'pat, 'tcx> { } PatKind::Array { ref prefix, ref slice, ref suffix } => { - cx.prefix_slice_suffix(&mut subpairs, &place_builder, prefix, slice, suffix); + cx.prefix_slice_suffix( + pattern, + &mut subpairs, + &place_builder, + prefix, + slice, + suffix, + ); default_irrefutable() } PatKind::Slice { ref prefix, ref slice, ref suffix } => { - cx.prefix_slice_suffix(&mut subpairs, &place_builder, prefix, slice, suffix); + cx.prefix_slice_suffix( + pattern, + &mut subpairs, + &place_builder, + prefix, + slice, + suffix, + ); if prefix.is_empty() && slice.is_some() && suffix.is_empty() { default_irrefutable() diff --git a/compiler/rustc_mir_build/src/builder/matches/mod.rs b/compiler/rustc_mir_build/src/builder/matches/mod.rs index b21ec8f3083b3..1da96c7e0dacf 100644 --- a/compiler/rustc_mir_build/src/builder/matches/mod.rs +++ b/compiler/rustc_mir_build/src/builder/matches/mod.rs @@ -1238,6 +1238,7 @@ enum TestCase<'pat, 'tcx> { Irrefutable { binding: Option>, ascription: Option> }, Variant { adt_def: ty::AdtDef<'tcx>, variant_index: VariantIdx }, Constant { value: mir::Const<'tcx> }, + FusedConstant { value: mir::Const<'tcx>, fused: u64 }, Range(&'pat PatRange<'tcx>), Slice { len: usize, variable_length: bool }, Deref { temp: Place<'tcx>, mutability: Mutability }, @@ -1304,7 +1305,7 @@ enum TestKind<'tcx> { /// /// The test's target values are not stored here; instead they are extracted /// from the [`TestCase`]s of the candidates participating in the test. - SwitchInt, + SwitchInt { fused: u64 }, /// Test whether a `bool` is `true` or `false`. If, diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index afe6b4475be3c..83544fa77c288 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -19,8 +19,8 @@ use rustc_span::source_map::Spanned; use rustc_span::{DUMMY_SP, Span, Symbol, sym}; use tracing::{debug, instrument}; -use crate::builder::Builder; use crate::builder::matches::{Candidate, MatchPairTree, Test, TestBranch, TestCase, TestKind}; +use crate::builder::{Builder, PlaceBuilder}; impl<'a, 'tcx> Builder<'a, 'tcx> { /// Identifies what test is needed to decide if `match_pair` is applicable. @@ -34,9 +34,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { TestCase::Variant { adt_def, variant_index: _ } => TestKind::Switch { adt_def }, TestCase::Constant { .. } if match_pair.pattern.ty.is_bool() => TestKind::If, - TestCase::Constant { .. } if is_switch_ty(match_pair.pattern.ty) => TestKind::SwitchInt, + TestCase::Constant { .. } if is_switch_ty(match_pair.pattern.ty) => { + TestKind::SwitchInt { fused: 1 } + } TestCase::Constant { value } => TestKind::Eq { value, ty: match_pair.pattern.ty }, + TestCase::FusedConstant { fused, .. } => TestKind::SwitchInt { fused }, + TestCase::Range(range) => { assert_eq!(range.ty, match_pair.pattern.ty); TestKind::Range(Box::new(range.clone())) @@ -113,7 +117,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ); } - TestKind::SwitchInt => { + TestKind::SwitchInt { fused } => { // The switch may be inexhaustive so we have a catch-all block let otherwise_block = target_block(TestBranch::Failure); let switch_targets = SwitchTargets::new( @@ -126,10 +130,16 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { }), otherwise_block, ); - let terminator = TerminatorKind::SwitchInt { - discr: Operand::Copy(place), - targets: switch_targets, + + let discr = match fused { + 0 => span_bug!(test.span, "there must be at least one constant"), + 1 => Operand::Copy(place), + 2.. => { + self.fuse_switch_discriminant(block, place, place_ty.ty, fused, test.span) + } }; + + let terminator = TerminatorKind::SwitchInt { discr, targets: switch_targets }; self.cfg.terminate(block, self.source_info(match_start_span), terminator); } @@ -337,6 +347,91 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { }); } + /// "Fuse" multiple small integer constants in a sequence into a single integer, possibly + /// removing unecessary branches from the lowered match tree. + fn fuse_switch_discriminant( + &mut self, + block: BasicBlock, + place: Place<'tcx>, + elem_ty: Ty<'tcx>, + count: u64, + test_span: Span, + ) -> Operand<'tcx> { + let tcx = self.tcx; + let source_info = self.source_info(test_span); + match (count, elem_ty) { + (2..=4, ty) if ty == tcx.types.u8 || ty == tcx.types.i8 => (), + (2..=2, ty) if ty == tcx.types.u16 || ty == tcx.types.i16 => (), + (fused, ty) => span_bug!( + test_span, + "unsupported constant fusion combination of count {} and type {}", + ty, + fused + ), + }; + + let fused_ty = match count * elem_ty.primitive_size(tcx).bits() { + 8..=16 => tcx.types.u16, + ..=32 => tcx.types.u32, + _ => unreachable!(), + }; + + let builder = PlaceBuilder::from(place); + let place_for = move |b: &mut Self, idx| { + let mut builder = builder.clone(); + match builder.projection_mut() { + [.., ProjectionElem::ConstantIndex { offset, ref from_end, .. }] => { + if !from_end { + *offset += idx; + } else { + *offset -= idx; + } + } + _ => span_bug!(test_span, "found unexpected projections"), + } + builder.to_place(b) + }; + + let temp = self.temp(fused_ty, DUMMY_SP); + let acc = self.temp(fused_ty, DUMMY_SP); + + // Since we can freely cast up integers + the required shift is zero on the first + // iteration, we skip both the shift and OR operations the first time. + self.cfg.push_assign( + block, + source_info, + acc, + Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), fused_ty), + ); + + // Handle all but the first iterations, iteratively building up the fused integer. + for i in 1..count { + let place = place_for(self, i); + let shift = self.literal_operand(DUMMY_SP, Const::from_usize(tcx, i * 8)); + + self.cfg.push_assign( + block, + source_info, + temp, + Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), fused_ty), + ); + self.cfg.push_assign( + block, + source_info, + temp, + Rvalue::BinaryOp(BinOp::Shl, Box::new((Operand::Copy(temp), shift))), + ); + self.cfg.push_assign( + block, + source_info, + acc, + Rvalue::BinaryOp(BinOp::BitOr, Box::new((Operand::Copy(acc), Operand::Copy(temp)))), + ); + } + + Operand::Copy(acc) + } + /// Compare using the provided built-in comparison operator fn compare( &mut self, @@ -557,9 +652,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // // FIXME(#29623) we could use PatKind::Range to rule // things out here, in some cases. - (TestKind::SwitchInt, &TestCase::Constant { value }) - if is_switch_ty(match_pair.pattern.ty) => - { + ( + TestKind::SwitchInt { .. }, + &TestCase::Constant { value } | &TestCase::FusedConstant { value, .. }, + ) if is_switch_ty(match_pair.pattern.ty) => { // An important invariant of candidate sorting is that a candidate // must not match in multiple branches. For `SwitchInt` tests, adding // a new value might invalidate that property for range patterns that @@ -591,7 +687,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { Some(TestBranch::Constant(value, bits)) } } - (TestKind::SwitchInt, TestCase::Range(range)) => { + (TestKind::SwitchInt { fused: _fused }, TestCase::Range(range)) => { // When performing a `SwitchInt` test, a range pattern can be // sorted into the failure arm if it doesn't contain _any_ of // the values being tested. (This restricts what values can be diff --git a/compiler/rustc_mir_build/src/builder/matches/util.rs b/compiler/rustc_mir_build/src/builder/matches/util.rs index 1bd399e511b39..6be0ea5aa11dd 100644 --- a/compiler/rustc_mir_build/src/builder/matches/util.rs +++ b/compiler/rustc_mir_build/src/builder/matches/util.rs @@ -1,3 +1,5 @@ +use std::ops; + use rustc_data_structures::fx::FxIndexMap; use rustc_middle::mir::*; use rustc_middle::ty::Ty; @@ -229,3 +231,46 @@ pub(crate) fn ref_pat_borrow_kind(ref_mutability: Mutability) -> BorrowKind { Mutability::Not => BorrowKind::Shared, } } + +#[derive(Copy, Clone, PartialEq, Debug)] +pub(super) struct Range { + pub(super) start: u64, + pub(super) end: u64, + pub(super) from_end: bool, +} + +impl Range { + pub(super) fn from_start(range: ops::Range) -> Self { + Range { start: range.start, end: range.end, from_end: false } + } + + pub(super) fn from_end(range: ops::Range) -> Self { + Range { start: range.end, end: range.start, from_end: true } + } + + pub(super) fn len(self) -> u64 { + if !self.from_end { self.end - self.start } else { self.start - self.end } + } + + pub(super) fn apply(self, slice: &[T]) -> &[T] { + if !self.from_end { + &slice[self.start as usize..self.end as usize] + } else { + &slice[..self.start as usize - self.end as usize] + } + } + + pub(super) fn shift_idx(self, idx: u64) -> u64 { + if !self.from_end { self.start + idx } else { self.start - idx } + } + + pub(super) fn shift_range(self, range_within: ops::Range) -> Self { + if !self.from_end { + Self::from_start(self.start + range_within.start..self.start + range_within.end) + } else { + let range_within_start = range_within.end; + let range_within_end = range_within.start; + Self::from_end(self.start - range_within_start..self.start - range_within_end) + } + } +}