Skip to content

Commit e00cdd7

Browse files
committed
Improve time complexity of equality relations
This PR adds a `UnificationTable` to the `TypeVariableTable` type which is used store information about variable equality instead of just storing them in a vector for later processing. By using a `UnificationTable` equality relations can be resolved in O(n) (for all realistic values of n) rather than O(n!) which can give massive speedups in certain cases (see combine as an example). Link to combine: https://github.com/Marwes/combine
1 parent 6d262db commit e00cdd7

File tree

12 files changed

+183
-42
lines changed

12 files changed

+183
-42
lines changed

Diff for: src/librustc/middle/infer/bivariate.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ impl<'a, 'tcx> TypeRelation<'a, 'tcx> for Bivariate<'a, 'tcx> {
7777
if a == b { return Ok(a); }
7878

7979
let infcx = self.fields.infcx;
80-
let a = infcx.type_variables.borrow().replace_if_possible(a);
81-
let b = infcx.type_variables.borrow().replace_if_possible(b);
80+
let a = infcx.type_variables.borrow_mut().replace_if_possible(a);
81+
let b = infcx.type_variables.borrow_mut().replace_if_possible(b);
8282
match (&a.sty, &b.sty) {
8383
(&ty::TyInfer(TyVar(a_id)), &ty::TyInfer(TyVar(b_id))) => {
8484
infcx.type_variables.borrow_mut().relate_vars(a_id, BiTo, b_id);

Diff for: src/librustc/middle/infer/combine.rs

+13-3
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,12 @@ impl<'a, 'tcx> CombineFields<'a, 'tcx> {
210210
None => break,
211211
Some(e) => e,
212212
};
213+
// Get the actual variable that b_vid has been inferred to
214+
let (b_vid, b_ty) = {
215+
let mut variables = self.infcx.type_variables.borrow_mut();
216+
let b_vid = variables.root_var(b_vid);
217+
(b_vid, variables.probe_root(b_vid))
218+
};
213219

214220
debug!("instantiate(a_ty={:?} dir={:?} b_vid={:?})",
215221
a_ty,
@@ -219,7 +225,6 @@ impl<'a, 'tcx> CombineFields<'a, 'tcx> {
219225
// Check whether `vid` has been instantiated yet. If not,
220226
// make a generalized form of `ty` and instantiate with
221227
// that.
222-
let b_ty = self.infcx.type_variables.borrow().probe(b_vid);
223228
let b_ty = match b_ty {
224229
Some(t) => t, // ...already instantiated.
225230
None => { // ...not yet instantiated:
@@ -307,12 +312,17 @@ impl<'cx, 'tcx> ty::fold::TypeFolder<'tcx> for Generalizer<'cx, 'tcx> {
307312
// where `$1` has already been instantiated with `Box<$0>`)
308313
match t.sty {
309314
ty::TyInfer(ty::TyVar(vid)) => {
315+
let mut variables = self.infcx.type_variables.borrow_mut();
316+
let vid = variables.root_var(vid);
310317
if vid == self.for_vid {
311318
self.cycle_detected = true;
312319
self.tcx().types.err
313320
} else {
314-
match self.infcx.type_variables.borrow().probe(vid) {
315-
Some(u) => self.fold_ty(u),
321+
match variables.probe_root(vid) {
322+
Some(u) => {
323+
drop(variables);
324+
self.fold_ty(u)
325+
}
316326
None => t,
317327
}
318328
}

Diff for: src/librustc/middle/infer/equate.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ impl<'a, 'tcx> TypeRelation<'a,'tcx> for Equate<'a, 'tcx> {
5050
if a == b { return Ok(a); }
5151

5252
let infcx = self.fields.infcx;
53-
let a = infcx.type_variables.borrow().replace_if_possible(a);
54-
let b = infcx.type_variables.borrow().replace_if_possible(b);
53+
let a = infcx.type_variables.borrow_mut().replace_if_possible(a);
54+
let b = infcx.type_variables.borrow_mut().replace_if_possible(b);
5555
match (&a.sty, &b.sty) {
5656
(&ty::TyInfer(TyVar(a_id)), &ty::TyInfer(TyVar(b_id))) => {
5757
infcx.type_variables.borrow_mut().relate_vars(a_id, EqTo, b_id);

Diff for: src/librustc/middle/infer/freshen.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,9 @@ impl<'a, 'tcx> TypeFolder<'tcx> for TypeFreshener<'a, 'tcx> {
111111

112112
match t.sty {
113113
ty::TyInfer(ty::TyVar(v)) => {
114+
let opt_ty = self.infcx.type_variables.borrow_mut().probe(v);
114115
self.freshen(
115-
self.infcx.type_variables.borrow().probe(v),
116+
opt_ty,
116117
ty::TyVar(v),
117118
ty::FreshTy)
118119
}

Diff for: src/librustc/middle/infer/higher_ranked/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ impl<'a,'tcx> InferCtxtExt for InferCtxt<'a,'tcx> {
434434
self.region_vars.vars_created_since_snapshot(&snapshot.region_vars_snapshot);
435435

436436
let escaping_types =
437-
self.type_variables.borrow().types_escaping_snapshot(&snapshot.type_snapshot);
437+
self.type_variables.borrow_mut().types_escaping_snapshot(&snapshot.type_snapshot);
438438

439439
let mut escaping_region_vars = FnvHashSet();
440440
for ty in &escaping_types {

Diff for: src/librustc/middle/infer/lattice.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ pub fn super_lattice_tys<'a,'tcx,L:LatticeDir<'a,'tcx>>(this: &mut L,
6060
}
6161

6262
let infcx = this.infcx();
63-
let a = infcx.type_variables.borrow().replace_if_possible(a);
64-
let b = infcx.type_variables.borrow().replace_if_possible(b);
63+
let a = infcx.type_variables.borrow_mut().replace_if_possible(a);
64+
let b = infcx.type_variables.borrow_mut().replace_if_possible(b);
6565
match (&a.sty, &b.sty) {
6666
(&ty::TyInfer(TyVar(..)), &ty::TyInfer(TyVar(..)))
6767
if infcx.type_var_diverges(a) && infcx.type_var_diverges(b) => {

Diff for: src/librustc/middle/infer/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
637637
let mut variables = Vec::new();
638638

639639
let unbound_ty_vars = self.type_variables
640-
.borrow()
640+
.borrow_mut()
641641
.unsolved_variables()
642642
.into_iter()
643643
.map(|t| self.tcx.mk_var(t));
@@ -1162,7 +1162,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
11621162
// structurally), and we prevent cycles in any case,
11631163
// so this recursion should always be of very limited
11641164
// depth.
1165-
self.type_variables.borrow()
1165+
self.type_variables.borrow_mut()
11661166
.probe(v)
11671167
.map(|t| self.shallow_resolve(t))
11681168
.unwrap_or(typ)

Diff for: src/librustc/middle/infer/sub.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ impl<'a, 'tcx> TypeRelation<'a, 'tcx> for Sub<'a, 'tcx> {
6565
if a == b { return Ok(a); }
6666

6767
let infcx = self.fields.infcx;
68-
let a = infcx.type_variables.borrow().replace_if_possible(a);
69-
let b = infcx.type_variables.borrow().replace_if_possible(b);
68+
let a = infcx.type_variables.borrow_mut().replace_if_possible(a);
69+
let b = infcx.type_variables.borrow_mut().replace_if_possible(b);
7070
match (&a.sty, &b.sty) {
7171
(&ty::TyInfer(TyVar(a_id)), &ty::TyInfer(TyVar(b_id))) => {
7272
infcx.type_variables

Diff for: src/librustc/middle/infer/type_variable.rs

+80-18
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ use std::marker::PhantomData;
2020
use std::mem;
2121
use std::u32;
2222
use rustc_data_structures::snapshot_vec as sv;
23+
use rustc_data_structures::unify as ut;
2324

2425
pub struct TypeVariableTable<'tcx> {
2526
values: sv::SnapshotVec<Delegate<'tcx>>,
27+
eq_relations: ut::UnificationTable<ty::TyVid>,
2628
}
2729

2830
struct TypeVariableData<'tcx> {
@@ -50,20 +52,22 @@ pub struct Default<'tcx> {
5052
}
5153

5254
pub struct Snapshot {
53-
snapshot: sv::Snapshot
55+
snapshot: sv::Snapshot,
56+
eq_snapshot: ut::Snapshot<ty::TyVid>,
5457
}
5558

5659
enum UndoEntry<'tcx> {
5760
// The type of the var was specified.
5861
SpecifyVar(ty::TyVid, Vec<Relation>, Option<Default<'tcx>>),
5962
Relate(ty::TyVid, ty::TyVid),
63+
RelateRange(ty::TyVid, usize),
6064
}
6165

6266
struct Delegate<'tcx>(PhantomData<&'tcx ()>);
6367

6468
type Relation = (RelationDir, ty::TyVid);
6569

66-
#[derive(Copy, Clone, PartialEq, Debug)]
70+
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
6771
pub enum RelationDir {
6872
SubtypeOf, SupertypeOf, EqTo, BiTo
6973
}
@@ -81,7 +85,10 @@ impl RelationDir {
8185

8286
impl<'tcx> TypeVariableTable<'tcx> {
8387
pub fn new() -> TypeVariableTable<'tcx> {
84-
TypeVariableTable { values: sv::SnapshotVec::new() }
88+
TypeVariableTable {
89+
values: sv::SnapshotVec::new(),
90+
eq_relations: ut::UnificationTable::new(),
91+
}
8592
}
8693

8794
fn relations<'a>(&'a mut self, a: ty::TyVid) -> &'a mut Vec<Relation> {
@@ -103,22 +110,48 @@ impl<'tcx> TypeVariableTable<'tcx> {
103110
///
104111
/// Precondition: neither `a` nor `b` are known.
105112
pub fn relate_vars(&mut self, a: ty::TyVid, dir: RelationDir, b: ty::TyVid) {
113+
let a = self.root_var(a);
114+
let b = self.root_var(b);
106115
if a != b {
107-
self.relations(a).push((dir, b));
108-
self.relations(b).push((dir.opposite(), a));
109-
self.values.record(Relate(a, b));
116+
if dir == EqTo {
117+
// a and b must be equal which we mark in the unification table
118+
let root = self.eq_relations.union(a, b);
119+
// In addition to being equal, all relations from the variable which is no longer
120+
// the root must be added to the root so they are not forgotten as the other
121+
// variable should no longer be referenced (other than to get the root)
122+
let other = if a == root { b } else { a };
123+
let count = {
124+
let (relations, root_relations) = if other.index < root.index {
125+
let (pre, post) = self.values.split_at_mut(root.index as usize);
126+
(relations(&mut pre[other.index as usize]), relations(&mut post[0]))
127+
} else {
128+
let (pre, post) = self.values.split_at_mut(other.index as usize);
129+
(relations(&mut post[0]), relations(&mut pre[root.index as usize]))
130+
};
131+
root_relations.extend_from_slice(relations);
132+
relations.len()
133+
};
134+
self.values.record(RelateRange(root, count));
135+
} else {
136+
self.relations(a).push((dir, b));
137+
self.relations(b).push((dir.opposite(), a));
138+
self.values.record(Relate(a, b));
139+
}
110140
}
111141
}
112142

113143
/// Instantiates `vid` with the type `ty` and then pushes an entry onto `stack` for each of the
114144
/// relations of `vid` to other variables. The relations will have the form `(ty, dir, vid1)`
115145
/// where `vid1` is some other variable id.
146+
///
147+
/// Precondition: `vid` must be a root in the unification table
116148
pub fn instantiate_and_push(
117149
&mut self,
118150
vid: ty::TyVid,
119151
ty: Ty<'tcx>,
120152
stack: &mut Vec<(Ty<'tcx>, RelationDir, ty::TyVid)>)
121153
{
154+
debug_assert!(self.root_var(vid) == vid);
122155
let old_value = {
123156
let value_ptr = &mut self.values.get_mut(vid.index as usize).value;
124157
mem::replace(value_ptr, Known(ty))
@@ -140,21 +173,33 @@ impl<'tcx> TypeVariableTable<'tcx> {
140173
pub fn new_var(&mut self,
141174
diverging: bool,
142175
default: Option<Default<'tcx>>) -> ty::TyVid {
176+
self.eq_relations.new_key(());
143177
let index = self.values.push(TypeVariableData {
144178
value: Bounded { relations: vec![], default: default },
145179
diverging: diverging
146180
});
147181
ty::TyVid { index: index as u32 }
148182
}
149183

150-
pub fn probe(&self, vid: ty::TyVid) -> Option<Ty<'tcx>> {
184+
pub fn root_var(&mut self, vid: ty::TyVid) -> ty::TyVid {
185+
self.eq_relations.find(vid)
186+
}
187+
188+
pub fn probe(&mut self, vid: ty::TyVid) -> Option<Ty<'tcx>> {
189+
let vid = self.root_var(vid);
190+
self.probe_root(vid)
191+
}
192+
193+
/// Retrieves the type of `vid` given that it is currently a root in the unification table
194+
pub fn probe_root(&mut self, vid: ty::TyVid) -> Option<Ty<'tcx>> {
195+
debug_assert!(self.root_var(vid) == vid);
151196
match self.values.get(vid.index as usize).value {
152197
Bounded { .. } => None,
153198
Known(t) => Some(t)
154199
}
155200
}
156201

157-
pub fn replace_if_possible(&self, t: Ty<'tcx>) -> Ty<'tcx> {
202+
pub fn replace_if_possible(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
158203
match t.sty {
159204
ty::TyInfer(ty::TyVar(v)) => {
160205
match self.probe(v) {
@@ -167,18 +212,23 @@ impl<'tcx> TypeVariableTable<'tcx> {
167212
}
168213

169214
pub fn snapshot(&mut self) -> Snapshot {
170-
Snapshot { snapshot: self.values.start_snapshot() }
215+
Snapshot {
216+
snapshot: self.values.start_snapshot(),
217+
eq_snapshot: self.eq_relations.snapshot(),
218+
}
171219
}
172220

173221
pub fn rollback_to(&mut self, s: Snapshot) {
174222
self.values.rollback_to(s.snapshot);
223+
self.eq_relations.rollback_to(s.eq_snapshot);
175224
}
176225

177226
pub fn commit(&mut self, s: Snapshot) {
178227
self.values.commit(s.snapshot);
228+
self.eq_relations.commit(s.eq_snapshot);
179229
}
180230

181-
pub fn types_escaping_snapshot(&self, s: &Snapshot) -> Vec<Ty<'tcx>> {
231+
pub fn types_escaping_snapshot(&mut self, s: &Snapshot) -> Vec<Ty<'tcx>> {
182232
/*!
183233
* Find the set of type variables that existed *before* `s`
184234
* but which have only been unified since `s` started, and
@@ -208,7 +258,10 @@ impl<'tcx> TypeVariableTable<'tcx> {
208258
if vid.index < new_elem_threshold {
209259
// quick check to see if this variable was
210260
// created since the snapshot started or not.
211-
let escaping_type = self.probe(vid).unwrap();
261+
let escaping_type = match self.values.get(vid.index as usize).value {
262+
Bounded { .. } => unreachable!(),
263+
Known(ty) => ty,
264+
};
212265
escaping_types.push(escaping_type);
213266
}
214267
debug!("SpecifyVar({:?}) new_elem_threshold={}", vid, new_elem_threshold);
@@ -221,13 +274,15 @@ impl<'tcx> TypeVariableTable<'tcx> {
221274
escaping_types
222275
}
223276

224-
pub fn unsolved_variables(&self) -> Vec<ty::TyVid> {
225-
self.values
226-
.iter()
227-
.enumerate()
228-
.filter_map(|(i, value)| match &value.value {
229-
&TypeVariableValue::Known(_) => None,
230-
&TypeVariableValue::Bounded { .. } => Some(ty::TyVid { index: i as u32 })
277+
pub fn unsolved_variables(&mut self) -> Vec<ty::TyVid> {
278+
(0..self.values.len())
279+
.filter_map(|i| {
280+
let vid = ty::TyVid { index: i as u32 };
281+
if self.probe(vid).is_some() {
282+
None
283+
} else {
284+
Some(vid)
285+
}
231286
})
232287
.collect()
233288
}
@@ -250,6 +305,13 @@ impl<'tcx> sv::SnapshotVecDelegate for Delegate<'tcx> {
250305
relations(&mut (*values)[a.index as usize]).pop();
251306
relations(&mut (*values)[b.index as usize]).pop();
252307
}
308+
309+
RelateRange(i, n) => {
310+
let relations = relations(&mut (*values)[i.index as usize]);
311+
for _ in 0..n {
312+
relations.pop();
313+
}
314+
}
253315
}
254316
}
255317
}

Diff for: src/librustc/middle/infer/unify_key.rs

+7
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,10 @@ impl<'tcx> ToType<'tcx> for ast::FloatTy {
7373
tcx.mk_mach_float(*self)
7474
}
7575
}
76+
77+
impl UnifyKey for ty::TyVid {
78+
type Value = ();
79+
fn index(&self) -> u32 { self.index }
80+
fn from_index(i: u32) -> ty::TyVid { ty::TyVid { index: i } }
81+
fn tag(_: Option<ty::TyVid>) -> &'static str { "TyVid" }
82+
}

0 commit comments

Comments
 (0)