Commit e599b53
Auto merge of rust-lang#76918 - ishitatsuyuki:match-fastpath, r=oli-obk
Add fast path for match checking

This adds a fast path that reduces the complexity to linear for matches consisting only of variant patterns (i.e. enum matches). (Also see: rust-lang#7462)

Unfortunately, I was too lazy to add a similar fast path for constants (mostly for integer matches); ideally that could be added another day.

TBH, I'm not confident in the performance claims, because enums tend to be small and FxHashMap could add a lot of overhead.

r? `@Mark-Simulacrum`

needs perf
2 parents 87d262a + 01a771a commit e599b53
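The kind of code this fast path targets is a match whose arms are all plain enum-variant patterns, for example (hypothetical code, not part of this commit):

    // Every arm's pattern is a variant of the matched enum, so usefulness
    // checking can precompute which rows each variant's specialization keeps
    // instead of re-scanning the whole pattern matrix for every constructor.
    enum Token { Plus, Minus, Star, Slash, Eof }

    fn describe(t: Token) -> &'static str {
        match t {
            Token::Plus => "+",
            Token::Minus => "-",
            Token::Star => "*",
            Token::Slash => "/",
            Token::Eof => "end of input",
        }
    }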

2 files changed, +165 -22 lines changed
compiler/rustc_mir_build/src/thir/pattern/_match.rs

+161 -18
@@ -139,10 +139,10 @@
 //!
 //! It is computed as follows. We look at the pattern `p_1` on top of the stack,
 //! and we have three cases:
-//! 1.1. `p_1 = c(r_1, .., r_a)`. We discard the current stack and return nothing.
-//! 1.2. `p_1 = _`. We return the rest of the stack:
+//! 2.1. `p_1 = c(r_1, .., r_a)`. We discard the current stack and return nothing.
+//! 2.2. `p_1 = _`. We return the rest of the stack:
 //!                 p_2, .., p_n
-//! 1.3. `p_1 = r_1 | r_2`. We expand the OR-pattern and then recurse on each resulting
+//! 2.3. `p_1 = r_1 | r_2`. We expand the OR-pattern and then recurse on each resulting
 //!      stack.
 //!         D((r_1, p_2, .., p_n))
 //!         D((r_2, p_2, .., p_n))
@@ -276,7 +276,7 @@ use self::Usefulness::*;
 use self::WitnessPreference::*;
 
 use rustc_data_structures::captures::Captures;
-use rustc_data_structures::fx::FxHashSet;
+use rustc_data_structures::fx::{FxHashMap, FxHashSet};
 use rustc_index::vec::Idx;
 
 use super::{compare_const_vals, PatternFoldable, PatternFolder};
@@ -416,7 +416,7 @@ impl<'tcx> Pat<'tcx> {
 
 /// A row of a matrix. Rows of len 1 are very common, which is why `SmallVec[_; 2]`
 /// works well.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq)]
 crate struct PatStack<'p, 'tcx>(SmallVec<[&'p Pat<'tcx>; 2]>);
 
 impl<'p, 'tcx> PatStack<'p, 'tcx> {
@@ -504,13 +504,36 @@ impl<'p, 'tcx> FromIterator<&'p Pat<'tcx>> for PatStack<'p, 'tcx> {
     }
 }
 
+/// Depending on the match patterns, the specialization process might be able to use a fast path.
+/// Tracks whether we can use the fast path and the lookup table needed in those cases.
+#[derive(Clone, Debug, PartialEq)]
+enum SpecializationCache {
+    /// Patterns consist of only enum variants.
+    /// Variant patterns do not intersect with each other (in contrast to range patterns),
+    /// so it is possible to precompute the result of `Matrix::specialize_constructor` at a
+    /// lower computational complexity.
+    /// `lookup` is responsible for holding the precomputed result of
+    /// `Matrix::specialize_constructor`, while `wilds` is used for two purposes: the first one is
+    /// the precomputed result of `Matrix::specialize_wildcard`, and the second is to be used as a
+    /// fallback for `Matrix::specialize_constructor` when it tries to apply a constructor that
+    /// has not been seen in the `Matrix`. See `update_cache` for further explanations.
+    Variants { lookup: FxHashMap<DefId, SmallVec<[usize; 1]>>, wilds: SmallVec<[usize; 1]> },
+    /// Does not belong to the cases above; use the slow path.
+    Incompatible,
+}
+
 /// A 2D matrix.
-#[derive(Clone)]
-crate struct Matrix<'p, 'tcx>(Vec<PatStack<'p, 'tcx>>);
+#[derive(Clone, PartialEq)]
+crate struct Matrix<'p, 'tcx> {
+    patterns: Vec<PatStack<'p, 'tcx>>,
+    cache: SpecializationCache,
+}
 
 impl<'p, 'tcx> Matrix<'p, 'tcx> {
     crate fn empty() -> Self {
-        Matrix(vec![])
+        // Use `SpecializationCache::Incompatible` as a placeholder; we will initialize it on the
+        // first call to `push`. See the first half of `update_cache`.
+        Matrix { patterns: vec![], cache: SpecializationCache::Incompatible }
     }
 
     /// Pushes a new row to the matrix. If the row starts with an or-pattern, this expands it.
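To make the cached bookkeeping concrete, here is a minimal standalone sketch of the idea behind `SpecializationCache::Variants` (not the compiler's code: it uses string keys, `HashMap`, and `Vec` in place of `DefId`, `FxHashMap`, and `SmallVec`, and it assumes every row head is either a variant or a wildcard, omitting the `Incompatible` fallback):

    use std::collections::HashMap;

    // `lookup` maps a variant to the indices of the rows that specializing on
    // that variant must keep; `wilds` holds the indices of rows whose head is
    // a wildcard.
    #[derive(Default)]
    struct VariantCache {
        lookup: HashMap<&'static str, Vec<usize>>,
        wilds: Vec<usize>,
    }

    impl VariantCache {
        // `head` is `Some(variant_name)` for a variant pattern, `None` for a wildcard.
        fn push_row(&mut self, idx: usize, head: Option<&'static str>) {
            let VariantCache { lookup, wilds } = self;
            match head {
                None => {
                    // A wildcard row is kept when specializing on *any* constructor...
                    for rows in lookup.values_mut() {
                        rows.push(idx);
                    }
                    // ...and it is also the only kind of row kept by `specialize_wildcard`.
                    wilds.push(idx);
                }
                Some(variant) => {
                    // A variant row only matters when specializing on that variant; a
                    // variant seen for the first time inherits the wildcards so far.
                    lookup.entry(variant).or_insert_with(|| wilds.clone()).push(idx);
                }
            }
        }

        // Rows visited by `specialize_constructor` on `variant`; a variant that never
        // appears in the matrix falls back to the wildcard rows.
        fn rows_for_variant(&self, variant: &str) -> &[usize] {
            self.lookup.get(variant).map(Vec::as_slice).unwrap_or(&self.wilds)
        }

        // Rows visited by `specialize_wildcard`.
        fn rows_for_wildcard(&self) -> &[usize] {
            &self.wilds
        }
    }

    fn main() {
        let mut cache = VariantCache::default();
        // Row heads of a hypothetical match on Option<u32>:
        //   row 0: Some(..), row 1: None, row 2: _, row 3: Some(..)
        for (idx, head) in [Some("Some"), Some("None"), None, Some("Some")].iter().enumerate() {
            cache.push_row(idx, *head);
        }
        assert_eq!(cache.rows_for_variant("Some"), &[0, 2, 3]);
        assert_eq!(cache.rows_for_variant("None"), &[1, 2]);
        assert_eq!(cache.rows_for_variant("NeverSeen"), &[2]);
        assert_eq!(cache.rows_for_wildcard(), &[2]);
    }

Each specialization query then touches only the rows it can possibly keep, which is what brings the overall work down to roughly linear for matches consisting only of variant patterns.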
@@ -522,18 +545,101 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> {
                 self.push(row)
             }
         } else {
-            self.0.push(row);
+            self.patterns.push(row);
+            self.update_cache(self.patterns.len() - 1);
+        }
+    }
+
+    fn update_cache(&mut self, idx: usize) {
+        let row = &self.patterns[idx];
+        // We don't know which kind of cache can be used until we see the first row; therefore an
+        // empty `Matrix` starts with `SpecializationCache::Incompatible` as a placeholder, and the
+        // cache is assigned the appropriate variant below on the first call to `push`.
+        if idx == 0 {
+            self.cache = if row.is_empty() {
+                SpecializationCache::Incompatible
+            } else {
+                match *row.head().kind {
+                    PatKind::Variant { .. } => SpecializationCache::Variants {
+                        lookup: FxHashMap::default(),
+                        wilds: SmallVec::new(),
+                    },
+                    // Note: if the first pattern is a wildcard, then all patterns after it are
+                    // not useful. The check is simple enough that we treat it the same as
+                    // unsupported patterns.
+                    _ => SpecializationCache::Incompatible,
+                }
+            };
+        }
+        // Update the cache.
+        match &mut self.cache {
+            SpecializationCache::Variants { ref mut lookup, ref mut wilds } => {
+                let head = row.head();
+                match *head.kind {
+                    _ if head.is_wildcard() => {
+                        // Per rule 1.3 in the top-level comments, a wildcard pattern is included in
+                        // the result of `specialize_constructor` for *any* `Constructor`.
+                        // We push the wildcard pattern to the precomputed result for constructors
+                        // that we have seen before; results for constructors we have not yet seen
+                        // default to `wilds`, which is updated right below.
+                        for (_, v) in lookup.iter_mut() {
+                            v.push(idx);
+                        }
+                        // Per rules 2.1 and 2.2 in the top-level comments, only wildcard patterns
+                        // are included in the result of `specialize_wildcard`.
+                        // What we do here is to track the wildcards we have seen; so in addition to
+                        // acting as the precomputed result of `specialize_wildcard`, `wilds` also
+                        // serves as the default value of `specialize_constructor` for constructors
+                        // that are not in `lookup`.
+                        wilds.push(idx);
+                    }
+                    PatKind::Variant { adt_def, variant_index, .. } => {
+                        // Handle the cases of rule 1.1 and 1.2 in the top-level comments.
+                        // A variant pattern can only be included in the results of
+                        // `specialize_constructor` for a particular constructor, therefore we are
+                        // using a HashMap to track that.
+                        lookup
+                            .entry(adt_def.variants[variant_index].def_id)
+                            // Default to `wilds` for absent keys. See above for an explanation.
+                            .or_insert_with(|| wilds.clone())
+                            .push(idx);
+                    }
+                    _ => {
+                        self.cache = SpecializationCache::Incompatible;
+                    }
+                }
+            }
+            SpecializationCache::Incompatible => {}
         }
     }
 
     /// Iterate over the first component of each row
     fn heads<'a>(&'a self) -> impl Iterator<Item = &'a Pat<'tcx>> + Captures<'p> {
-        self.0.iter().map(|r| r.head())
+        self.patterns.iter().map(|r| r.head())
     }
 
     /// This computes `D(self)`. See top of the file for explanations.
     fn specialize_wildcard(&self) -> Self {
-        self.0.iter().filter_map(|r| r.specialize_wildcard()).collect()
+        match &self.cache {
+            SpecializationCache::Variants { wilds, .. } => {
+                let result =
+                    wilds.iter().filter_map(|&i| self.patterns[i].specialize_wildcard()).collect();
+                // When debug assertions are enabled, check the result against the "slow path"
+                // result.
+                debug_assert_eq!(
+                    result,
+                    Self {
+                        patterns: self.patterns.clone(),
+                        cache: SpecializationCache::Incompatible
+                    }
+                    .specialize_wildcard()
+                );
+                result
+            }
+            SpecializationCache::Incompatible => {
+                self.patterns.iter().filter_map(|r| r.specialize_wildcard()).collect()
+            }
+        }
     }
 
     /// This computes `S(constructor, self)`. See top of the file for explanations.
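For a concrete picture of what `update_cache` builds and what the fast path of `specialize_wildcard` then uses, consider a hypothetical match with the same row heads as the sketch above (illustrative only, not from the commit):

    // After the four rows below are pushed, the cache holds
    //   lookup[Some] = [0, 2, 3], lookup[None] = [1, 2], wilds = [2]
    // so `specialize_wildcard` visits only row 2, and specializing on `Some`
    // or `None` visits only the listed rows instead of the whole matrix.
    fn classify(x: Option<u32>) -> u32 {
        match x {
            Some(0) => 0, // row 0: head is a `Some` variant pattern
            None => 1,    // row 1: head is a `None` variant pattern
            _ => 2,       // row 2: head is a wildcard
            Some(_) => 3, // row 3: `Some` again (reported as unreachable)
        }
    }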
@@ -543,10 +649,47 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> {
         constructor: &Constructor<'tcx>,
         ctor_wild_subpatterns: &Fields<'p, 'tcx>,
     ) -> Matrix<'p, 'tcx> {
-        self.0
-            .iter()
-            .filter_map(|r| r.specialize_constructor(cx, constructor, ctor_wild_subpatterns))
-            .collect()
+        match &self.cache {
+            SpecializationCache::Variants { lookup, wilds } => {
+                let result: Self = if let Constructor::Variant(id) = constructor {
+                    lookup
+                        .get(id)
+                        // Default to `wilds` for absent keys. See `update_cache` for an explanation.
+                        .unwrap_or(&wilds)
+                        .iter()
+                        .filter_map(|&i| {
+                            self.patterns[i].specialize_constructor(
+                                cx,
+                                constructor,
+                                ctor_wild_subpatterns,
+                            )
+                        })
+                        .collect()
+                } else {
+                    unreachable!()
+                };
+                // When debug assertions are enabled, check the results against the "slow path"
+                // result.
+                debug_assert_eq!(
+                    result,
+                    Matrix {
+                        patterns: self.patterns.clone(),
+                        cache: SpecializationCache::Incompatible
+                    }
+                    .specialize_constructor(
+                        cx,
+                        constructor,
+                        ctor_wild_subpatterns
+                    )
+                );
+                result
+            }
+            SpecializationCache::Incompatible => self
+                .patterns
+                .iter()
+                .filter_map(|r| r.specialize_constructor(cx, constructor, ctor_wild_subpatterns))
+                .collect(),
+        }
     }
 }
 
@@ -568,7 +711,7 @@ impl<'p, 'tcx> fmt::Debug for Matrix<'p, 'tcx> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "\n")?;
 
-        let &Matrix(ref m) = self;
+        let Matrix { patterns: m, .. } = self;
         let pretty_printed_matrix: Vec<Vec<String>> =
             m.iter().map(|row| row.iter().map(|pat| format!("{:?}", pat)).collect()).collect();

@@ -1824,7 +1967,7 @@ crate fn is_useful<'p, 'tcx>(
     is_under_guard: bool,
     is_top_level: bool,
 ) -> Usefulness<'tcx> {
-    let &Matrix(ref rows) = matrix;
+    let Matrix { patterns: rows, .. } = matrix;
     debug!("is_useful({:#?}, {:#?})", matrix, v);
 
     // The base case. We are pattern-matching on () and the return value is
@@ -2266,7 +2409,7 @@ fn split_grouped_constructors<'p, 'tcx>(
             // `borders` is the set of borders between equivalence classes: each equivalence
             // class lies between 2 borders.
            let row_borders = matrix
-                .0
+                .patterns
                .iter()
                .flat_map(|row| {
                    IntRange::from_pat(tcx, param_env, row.head()).map(|r| (r, row.len()))

compiler/rustc_mir_build/src/thir/pattern/mod.rs

+4 -4
@@ -39,19 +39,19 @@ crate enum PatternError {
     NonConstPath(Span),
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Copy, Clone, Debug, PartialEq)]
 crate enum BindingMode {
     ByValue,
     ByRef(BorrowKind),
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 crate struct FieldPat<'tcx> {
     crate field: Field,
     crate pattern: Pat<'tcx>,
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 crate struct Pat<'tcx> {
     crate ty: Ty<'tcx>,
     crate span: Span,
@@ -116,7 +116,7 @@ crate struct Ascription<'tcx> {
     crate user_ty_span: Span,
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 crate enum PatKind<'tcx> {
     Wild,
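These `PartialEq` derives appear to be what lets the `debug_assert_eq!` cross-checks in `_match.rs` compare whole matrices: `Matrix` now derives `PartialEq`, and comparing it bottoms out in comparing `Pat` values. A minimal illustration of that transitive requirement (hypothetical standalone types, not the compiler's):

    // Deriving `PartialEq` on a container requires every field's type to be
    // `PartialEq` as well, which is why the THIR pattern types gain the derive
    // alongside `Matrix`.
    #[derive(Debug, PartialEq)]
    struct Pat(&'static str);

    #[derive(Debug, PartialEq)]
    struct Matrix {
        patterns: Vec<Vec<Pat>>,
    }

    fn main() {
        let fast = Matrix { patterns: vec![vec![Pat("_")]] };
        let slow = Matrix { patterns: vec![vec![Pat("_")]] };
        // Mirrors the cross-checking style used in `_match.rs`; only runs in debug builds.
        debug_assert_eq!(fast, slow);
    }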
