Skip to content

Commit 7c98f6f

Browse files
committed
Add fast path for match checking
1 parent 4e8a8b4 commit 7c98f6f

File tree

1 file changed

+95
-13
lines changed
  • compiler/rustc_mir_build/src/thir/pattern

1 file changed

+95
-13
lines changed

compiler/rustc_mir_build/src/thir/pattern/_match.rs

+95-13
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ use self::Usefulness::*;
276276
use self::WitnessPreference::*;
277277

278278
use rustc_data_structures::captures::Captures;
279-
use rustc_data_structures::fx::FxHashSet;
279+
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
280280
use rustc_index::vec::Idx;
281281

282282
use super::{compare_const_vals, PatternFoldable, PatternFolder};
@@ -504,13 +504,27 @@ impl<'p, 'tcx> FromIterator<&'p Pat<'tcx>> for PatStack<'p, 'tcx> {
504504
}
505505
}
506506

507+
/// Depending on the match patterns, the specialization process might be able to use a fast path.
508+
/// Tracks whether we can use the fast path and the lookup table needed in those cases.
509+
#[derive(Clone, Debug)]
510+
enum SpecializationCache {
511+
/// Patterns consist of only enum variants.
512+
Variants { lookup: FxHashMap<DefId, SmallVec<[usize; 1]>>, wilds: SmallVec<[usize; 1]> },
513+
/// Does not belong to the cases above, use the slow path.
514+
Incompatible,
515+
}
516+
507517
/// A 2D matrix.
508518
#[derive(Clone)]
509-
crate struct Matrix<'p, 'tcx>(Vec<PatStack<'p, 'tcx>>);
519+
crate struct Matrix<'p, 'tcx> {
520+
patterns: Vec<PatStack<'p, 'tcx>>,
521+
cache: SpecializationCache,
522+
}
510523

511524
impl<'p, 'tcx> Matrix<'p, 'tcx> {
512525
crate fn empty() -> Self {
513-
Matrix(vec![])
526+
// Use SpecializationCache::Incompatible as a placeholder; the initialization is in push().
527+
Matrix { patterns: vec![], cache: SpecializationCache::Incompatible }
514528
}
515529

516530
/// Pushes a new row to the matrix. If the row starts with an or-pattern, this expands it.
@@ -522,18 +536,65 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> {
522536
self.push(row)
523537
}
524538
} else {
525-
self.0.push(row);
539+
if self.patterns.is_empty() {
540+
self.cache = if row.is_empty() {
541+
SpecializationCache::Incompatible
542+
} else {
543+
match *row.head().kind {
544+
PatKind::Variant { .. } => SpecializationCache::Variants {
545+
lookup: FxHashMap::default(),
546+
wilds: SmallVec::new(),
547+
},
548+
// Note: If the first pattern is a wildcard, then all patterns after that is not
549+
// useful. The check is simple enough so we treat it as the same as unsupported
550+
// patterns.
551+
_ => SpecializationCache::Incompatible,
552+
}
553+
};
554+
}
555+
let idx_to_insert = self.patterns.len();
556+
match &mut self.cache {
557+
SpecializationCache::Variants { ref mut lookup, ref mut wilds } => {
558+
let head = row.head();
559+
match *head.kind {
560+
_ if head.is_wildcard() => {
561+
for (_, v) in lookup.iter_mut() {
562+
v.push(idx_to_insert);
563+
}
564+
wilds.push(idx_to_insert);
565+
}
566+
PatKind::Variant { adt_def, variant_index, .. } => {
567+
lookup
568+
.entry(adt_def.variants[variant_index].def_id)
569+
.or_insert_with(|| wilds.clone())
570+
.push(idx_to_insert);
571+
}
572+
_ => {
573+
self.cache = SpecializationCache::Incompatible;
574+
}
575+
}
576+
}
577+
SpecializationCache::Incompatible => {}
578+
}
579+
self.patterns.push(row);
526580
}
527581
}
528582

529583
/// Iterate over the first component of each row
530584
fn heads<'a>(&'a self) -> impl Iterator<Item = &'a Pat<'tcx>> + Captures<'p> {
531-
self.0.iter().map(|r| r.head())
585+
self.patterns.iter().map(|r| r.head())
532586
}
533587

534588
/// This computes `D(self)`. See top of the file for explanations.
535589
fn specialize_wildcard(&self) -> Self {
536-
self.0.iter().filter_map(|r| r.specialize_wildcard()).collect()
590+
match &self.cache {
591+
SpecializationCache::Variants { wilds, .. } => {
592+
wilds.iter().filter_map(|&i| self.patterns[i].specialize_wildcard()).collect()
593+
}
594+
SpecializationCache::Incompatible => {
595+
self.patterns.iter().filter_map(|r| r.specialize_wildcard()).collect()
596+
}
597+
}
537598
}
538599

539600
/// This computes `S(constructor, self)`. See top of the file for explanations.
@@ -543,10 +604,31 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> {
543604
constructor: &Constructor<'tcx>,
544605
ctor_wild_subpatterns: &Fields<'p, 'tcx>,
545606
) -> Matrix<'p, 'tcx> {
546-
self.0
547-
.iter()
548-
.filter_map(|r| r.specialize_constructor(cx, constructor, ctor_wild_subpatterns))
549-
.collect()
607+
match &self.cache {
608+
SpecializationCache::Variants { lookup, wilds } => {
609+
if let Constructor::Variant(id) = constructor {
610+
lookup
611+
.get(id)
612+
.unwrap_or(&wilds)
613+
.iter()
614+
.filter_map(|&i| {
615+
self.patterns[i].specialize_constructor(
616+
cx,
617+
constructor,
618+
ctor_wild_subpatterns,
619+
)
620+
})
621+
.collect()
622+
} else {
623+
unreachable!()
624+
}
625+
}
626+
SpecializationCache::Incompatible => self
627+
.patterns
628+
.iter()
629+
.filter_map(|r| r.specialize_constructor(cx, constructor, ctor_wild_subpatterns))
630+
.collect(),
631+
}
550632
}
551633
}
552634

@@ -568,7 +650,7 @@ impl<'p, 'tcx> fmt::Debug for Matrix<'p, 'tcx> {
568650
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
569651
write!(f, "\n")?;
570652

571-
let &Matrix(ref m) = self;
653+
let Matrix { patterns: m, .. } = self;
572654
let pretty_printed_matrix: Vec<Vec<String>> =
573655
m.iter().map(|row| row.iter().map(|pat| format!("{:?}", pat)).collect()).collect();
574656

@@ -1824,7 +1906,7 @@ crate fn is_useful<'p, 'tcx>(
18241906
is_under_guard: bool,
18251907
is_top_level: bool,
18261908
) -> Usefulness<'tcx> {
1827-
let &Matrix(ref rows) = matrix;
1909+
let Matrix { patterns: rows, .. } = matrix;
18281910
debug!("is_useful({:#?}, {:#?})", matrix, v);
18291911

18301912
// The base case. We are pattern-matching on () and the return value is
@@ -2266,7 +2348,7 @@ fn split_grouped_constructors<'p, 'tcx>(
22662348
// `borders` is the set of borders between equivalence classes: each equivalence
22672349
// class lies between 2 borders.
22682350
let row_borders = matrix
2269-
.0
2351+
.patterns
22702352
.iter()
22712353
.flat_map(|row| {
22722354
IntRange::from_pat(tcx, param_env, row.head()).map(|r| (r, row.len()))

0 commit comments

Comments
 (0)