Skip to content

Conditionally fuse small constant constant integer switches when lowering slice patterns #136417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ macro_rules! arena_types {
[decode] specialization_graph: rustc_middle::traits::specialization_graph::Graph,
[] crate_inherent_impls: rustc_middle::ty::CrateInherentImpls,
[] hir_owner_nodes: rustc_hir::OwnerNodes<'tcx>,
[] thir_pats: rustc_middle::thir::Pat<'tcx>,
]);
)
}
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_mir_build/src/builder/expr/as_place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,10 @@ impl<'tcx> PlaceBuilder<'tcx> {
&self.projection
}

pub(crate) fn projection_mut(&mut self) -> &mut [PlaceElem<'tcx>] {
&mut self.projection
}

pub(crate) fn field(self, f: FieldIdx, ty: Ty<'tcx>) -> Self {
self.project(PlaceElem::Field(f, ty))
}
Expand Down
295 changes: 279 additions & 16 deletions compiler/rustc_mir_build/src/builder/matches/match_pair.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use std::ops;

use either::Either;
use rustc_middle::bug;
use rustc_middle::mir::*;
use rustc_middle::thir::{self, *};
use rustc_middle::ty::{self, Ty, TypeVisitableExt};

use crate::builder::Builder;
use crate::builder::expr::as_place::{PlaceBase, PlaceBuilder};
use crate::builder::matches::util::Range;
use crate::builder::matches::{FlatPat, MatchPairTree, TestCase};

impl<'a, 'tcx> Builder<'a, 'tcx> {
Expand Down Expand Up @@ -33,6 +38,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
/// Used internally by [`MatchPairTree::for_pattern`].
fn prefix_slice_suffix<'pat>(
&mut self,
top_pattern: &'pat Pat<'tcx>,
match_pairs: &mut Vec<MatchPairTree<'pat, 'tcx>>,
place: &PlaceBuilder<'tcx>,
prefix: &'pat [Box<Pat<'tcx>>],
Expand All @@ -54,11 +60,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
((prefix.len() + suffix.len()).try_into().unwrap(), false)
};

match_pairs.extend(prefix.iter().enumerate().map(|(idx, subpattern)| {
let elem =
ProjectionElem::ConstantIndex { offset: idx as u64, min_length, from_end: false };
MatchPairTree::for_pattern(place.clone_project(elem), subpattern, self)
}));
if !prefix.is_empty() {
let bounds = Range::from_start(0..prefix.len() as u64);
let subpattern = bounds.apply(prefix);
self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length)
.for_each(|pair| match_pairs.push(pair));
}

if let Some(subslice_pat) = opt_slice {
let suffix_len = suffix.len() as u64;
Expand All @@ -70,16 +77,258 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
match_pairs.push(MatchPairTree::for_pattern(subslice, subslice_pat, self));
}

match_pairs.extend(suffix.iter().rev().enumerate().map(|(idx, subpattern)| {
let end_offset = (idx + 1) as u64;
let elem = ProjectionElem::ConstantIndex {
offset: if exact_size { min_length - end_offset } else { end_offset },
min_length,
from_end: !exact_size,
if !suffix.is_empty() {
let bounds = Range::from_end(0..suffix.len() as u64);
let subpattern = bounds.apply(suffix);
self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length)
.for_each(|pair| match_pairs.push(pair));
}
}

// Traverses either side of a slice pattern (prefix/suffix) and yields an iterator of `MatchPairTree`s
// to cover all it's constant and non-constant subpatterns.
fn build_slice_branch<'pat, 'b>(
&'b mut self,
bounds: Range,
place: &'b PlaceBuilder<'tcx>,
top_pattern: &'pat Pat<'tcx>,
pattern: &'pat [Box<Pat<'tcx>>],
min_length: u64,
) -> impl Iterator<Item = MatchPairTree<'pat, 'tcx>> + use<'a, 'tcx, 'pat, 'b> {
let entries = self.find_const_groups(pattern);

entries.into_iter().map(move |entry| {
// Common case handler for both non-constant and constant subpatterns not in a range.
let mut build_single = |idx| {
let subpattern = &pattern[idx as usize];
let place = place.clone_project(ProjectionElem::ConstantIndex {
offset: bounds.shift_idx(idx),
min_length: pattern.len() as u64,
from_end: bounds.from_end,
});

MatchPairTree::for_pattern(place, subpattern, self)
};
let place = place.clone_project(elem);
MatchPairTree::for_pattern(place, subpattern, self)
}));

match entry {
Either::Right(range) if range.end - range.start > 1 => {
// Figure out which subslice of our already sliced pattern we're looking at.
let subpattern = &pattern[range.start as usize..range.end as usize];
let elem_ty = subpattern[0].ty;

// Right, we 've found a group of constant patterns worth grouping for later.
// We'll collect all the leaves we can find and create a single `ValTree` out of them.
let valtree = self.simplify_const_pattern_slice_into_valtree(subpattern);
self.valtree_to_match_pair(
top_pattern,
valtree,
place.clone(),
elem_ty,
bounds.shift_range(range),
min_length,
)
}
Either::Right(range) => build_single(range.start),
Either::Left(idx) => build_single(idx),
}
})
}

// Given a partial view of the elements in a slice pattern, returns a list
// with left denoting non-constant element indices and right denoting ranges of constant elements.
fn find_const_groups(&self, pattern: &[Box<Pat<'tcx>>]) -> Vec<Either<u64, ops::Range<u64>>> {
let mut entries = Vec::new();
let mut current_seq_start = None;

for (idx, pat) in pattern.iter().enumerate() {
if self.is_constant_pattern(pat) {
if current_seq_start.is_none() {
current_seq_start = Some(idx as u64);
} else {
continue;
}
} else {
if let Some(start) = current_seq_start {
entries.push(Either::Right(start..idx as u64));
current_seq_start = None;
}
entries.push(Either::Left(idx as u64));
}
}

if let Some(start) = current_seq_start {
entries.push(Either::Right(start..pattern.len() as u64));
}

entries
}

// Checks if a pattern is constant and represented by a single scalar leaf.
fn is_constant_pattern(&self, pat: &Pat<'tcx>) -> bool {
if let PatKind::Constant { value } = pat.kind
&& let Const::Ty(_, const_) = value
&& let ty::ConstKind::Value(cv) = const_.kind()
&& let ty::ValTree::Leaf(_) = cv.valtree
{
true
} else {
false
}
}

// Extract the `ValTree` from a constant pattern.
// You must ensure that the pattern is a constant pattern before calling this function or it will panic.
fn extract_leaf(&self, pat: &Pat<'tcx>) -> ty::ValTree<'tcx> {
if let PatKind::Constant { value } = pat.kind
&& let Const::Ty(_, const_) = value
&& let ty::ConstKind::Value(cv) = const_.kind()
&& matches!(cv.valtree, ty::ValTree::Leaf(_))
{
cv.valtree
} else {
bug!("expected constant pattern, got {:?}", pat)
}
}

// Simplifies a slice of constant patterns into a single flattened `ValTree`.
fn simplify_const_pattern_slice_into_valtree(
&self,
subslice: &[Box<Pat<'tcx>>],
) -> ty::ValTree<'tcx> {
let leaves = subslice.iter().map(|p| self.extract_leaf(p));
let interned = self.tcx.arena.alloc_from_iter(leaves);
ty::ValTree::Branch(interned)
}

// Given a `ValTree` representing a slice of constant patterns, returns a `MatchPairTree`
// representing the slice pattern, providing as much info about subsequences in the slice as possible
// to later lowering stages.
fn valtree_to_match_pair<'pat>(
&mut self,
source_pattern: &'pat Pat<'tcx>,
valtree: ty::ValTree<'tcx>,
place: PlaceBuilder<'tcx>,
elem_ty: Ty<'tcx>,
range: Range,
min_length: u64,
) -> MatchPairTree<'pat, 'tcx> {
let tcx = self.tcx;
let leaves = match valtree {
ty::ValTree::Leaf(_) => bug!("expected branch, got leaf"),
ty::ValTree::Branch(leaves) => leaves,
};

assert!(range.len() == leaves.len() as u64);
let mut subpairs = Vec::new();
let mut were_merged = 0;

if elem_ty == tcx.types.u8 {
let leaf_bits = |leaf: ty::ValTree<'tcx>| match leaf {
ty::ValTree::Leaf(scalar) => scalar.to_u8(),
_ => bug!("found unflatted valtree"),
};

let mut fuse_group = |first_idx, len| {
were_merged += len;

let data = leaves[first_idx..first_idx + len]
.iter()
.rev()
.copied()
.map(leaf_bits)
.fold(0u32, |acc, x| (acc << 8) | u32::from(x));

let fused_ty = match len {
2 => tcx.types.u16,
3 | 4 => tcx.types.u32,
_ => unreachable!(),
};

let scalar = match len {
2 => ty::ScalarInt::from(data as u16),
3 | 4 => ty::ScalarInt::from(data),
_ => unreachable!(),
};

let valtree = ty::ValTree::Leaf(scalar);
let ty_const =
ty::Const::new(tcx, ty::ConstKind::Value(ty::Value { ty: fused_ty, valtree }));

let value = Const::Ty(fused_ty, ty_const);
let test_case = TestCase::FusedConstant { value, fused: len as u64 };

let pattern = tcx.arena.alloc(Pat {
ty: fused_ty,
span: source_pattern.span,
kind: PatKind::Constant { value },
});

let place = place
.clone_project(ProjectionElem::ConstantIndex {
offset: range.shift_idx(first_idx as u64),
min_length,
from_end: range.from_end,
})
.to_place(self);

subpairs.push(MatchPairTree {
place: Some(place),
test_case,
subpairs: Vec::new(),
pattern,
});
};

let indices = |group_size, skip| {
(skip..usize::MAX)
.take_while(move |i| i * group_size + (group_size - 1) < leaves.len())
};

let mut skip = 0;
for i in (2..=4).rev() {
for idx in indices(i, skip) {
fuse_group(idx * i, i);
skip += i;
}
}
}

for (idx, leaf) in leaves.iter().enumerate().skip(were_merged) {
let ty_const = ty::Const::new(
tcx,
ty::ConstKind::Value(ty::Value { ty: elem_ty, valtree: *leaf }),
);
let value = Const::Ty(elem_ty, ty_const);
let test_case = TestCase::Constant { value };

let pattern = tcx.arena.alloc(Pat {
ty: elem_ty,
span: source_pattern.span,
kind: PatKind::Constant { value },
});

let place = place
.clone_project(ProjectionElem::ConstantIndex {
offset: range.start + idx as u64,
min_length,
from_end: range.from_end,
})
.to_place(self);

subpairs.push(MatchPairTree {
place: Some(place),
test_case,
subpairs: Vec::new(),
pattern,
});
}

MatchPairTree {
place: None,
test_case: TestCase::Irrefutable { binding: None, ascription: None },
subpairs,
pattern: source_pattern,
}
}
}

Expand Down Expand Up @@ -192,11 +441,25 @@ impl<'pat, 'tcx> MatchPairTree<'pat, 'tcx> {
}

PatKind::Array { ref prefix, ref slice, ref suffix } => {
cx.prefix_slice_suffix(&mut subpairs, &place_builder, prefix, slice, suffix);
cx.prefix_slice_suffix(
pattern,
&mut subpairs,
&place_builder,
prefix,
slice,
suffix,
);
default_irrefutable()
}
PatKind::Slice { ref prefix, ref slice, ref suffix } => {
cx.prefix_slice_suffix(&mut subpairs, &place_builder, prefix, slice, suffix);
cx.prefix_slice_suffix(
pattern,
&mut subpairs,
&place_builder,
prefix,
slice,
suffix,
);

if prefix.is_empty() && slice.is_some() && suffix.is_empty() {
default_irrefutable()
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_mir_build/src/builder/matches/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,7 @@ enum TestCase<'pat, 'tcx> {
Irrefutable { binding: Option<Binding<'tcx>>, ascription: Option<Ascription<'tcx>> },
Variant { adt_def: ty::AdtDef<'tcx>, variant_index: VariantIdx },
Constant { value: mir::Const<'tcx> },
FusedConstant { value: mir::Const<'tcx>, fused: u64 },
Range(&'pat PatRange<'tcx>),
Slice { len: usize, variable_length: bool },
Deref { temp: Place<'tcx>, mutability: Mutability },
Expand Down Expand Up @@ -1304,7 +1305,7 @@ enum TestKind<'tcx> {
///
/// The test's target values are not stored here; instead they are extracted
/// from the [`TestCase`]s of the candidates participating in the test.
SwitchInt,
SwitchInt { fused: u64 },

/// Test whether a `bool` is `true` or `false`.
If,
Expand Down
Loading
Loading