Skip to content

Commit 9ef8f26

Browse files
committed
[red-knot] Add Type.definition method
1 parent 24498e3 commit 9ef8f26

File tree

10 files changed

+270
-226
lines changed

10 files changed

+270
-226
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/red_knot_ide/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ ruff_python_parser = { workspace = true }
1717
ruff_text_size = { workspace = true }
1818
red_knot_python_semantic = { workspace = true }
1919

20+
rustc-hash = { workspace = true }
2021
salsa = { workspace = true }
2122
smallvec = { workspace = true }
2223
tracing = { workspace = true }

crates/red_knot_ide/src/goto.rs

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::find_node::covering_node;
22
use crate::{Db, HasNavigationTargets, NavigationTargets, RangedValue};
3+
use red_knot_python_semantic::types::Type;
34
use red_knot_python_semantic::{HasType, SemanticModel};
45
use ruff_db::files::{File, FileRange};
56
use ruff_db::parsed::{parsed_module, ParsedModule};
@@ -16,40 +17,18 @@ pub fn goto_type_definition(
1617
let goto_target = find_goto_target(parsed, offset)?;
1718

1819
let model = SemanticModel::new(db.upcast(), file);
19-
20-
let ty = match goto_target {
21-
GotoTarget::Expression(expression) => expression.inferred_type(&model),
22-
GotoTarget::FunctionDef(function) => function.inferred_type(&model),
23-
GotoTarget::ClassDef(class) => class.inferred_type(&model),
24-
GotoTarget::Parameter(parameter) => parameter.inferred_type(&model),
25-
GotoTarget::Alias(alias) => alias.inferred_type(&model),
26-
GotoTarget::ExceptVariable(except) => except.inferred_type(&model),
27-
GotoTarget::KeywordArgument(argument) => {
28-
// TODO: Pyright resolves the declared type of the matching parameter. This seems more accurate
29-
// than using the inferred value.
30-
argument.value.inferred_type(&model)
31-
}
32-
// TODO: Support identifier targets
33-
GotoTarget::PatternMatchRest(_)
34-
| GotoTarget::PatternKeywordArgument(_)
35-
| GotoTarget::PatternMatchStarName(_)
36-
| GotoTarget::PatternMatchAsName(_)
37-
| GotoTarget::ImportedModule(_)
38-
| GotoTarget::TypeParamTypeVarName(_)
39-
| GotoTarget::TypeParamParamSpecName(_)
40-
| GotoTarget::TypeParamTypeVarTupleName(_)
41-
| GotoTarget::NonLocal { .. }
42-
| GotoTarget::Globals { .. } => return None,
43-
};
20+
let ty = goto_target.inferred_type(&model)?;
4421

4522
tracing::debug!(
4623
"Inferred type of covering node is {}",
4724
ty.display(db.upcast())
4825
);
4926

27+
let navigation_targets = ty.navigation_targets(db);
28+
5029
Some(RangedValue {
5130
range: FileRange::new(file, goto_target.range()),
52-
value: ty.navigation_targets(db),
31+
value: navigation_targets,
5332
})
5433
}
5534

@@ -173,6 +152,37 @@ impl Ranged for GotoTarget<'_> {
173152
}
174153
}
175154

155+
impl GotoTarget<'_> {
156+
pub(crate) fn inferred_type<'db>(&self, model: &SemanticModel<'db>) -> Option<Type<'db>> {
157+
let ty = match self {
158+
GotoTarget::Expression(expression) => expression.inferred_type(model),
159+
GotoTarget::FunctionDef(function) => function.inferred_type(model),
160+
GotoTarget::ClassDef(class) => class.inferred_type(model),
161+
GotoTarget::Parameter(parameter) => parameter.inferred_type(model),
162+
GotoTarget::Alias(alias) => alias.inferred_type(model),
163+
GotoTarget::ExceptVariable(except) => except.inferred_type(model),
164+
GotoTarget::KeywordArgument(argument) => {
165+
// TODO: Pyright resolves the declared type of the matching parameter. This seems more accurate
166+
// than using the inferred value.
167+
argument.value.inferred_type(model)
168+
}
169+
// TODO: Support identifier targets
170+
GotoTarget::PatternMatchRest(_)
171+
| GotoTarget::PatternKeywordArgument(_)
172+
| GotoTarget::PatternMatchStarName(_)
173+
| GotoTarget::PatternMatchAsName(_)
174+
| GotoTarget::ImportedModule(_)
175+
| GotoTarget::TypeParamTypeVarName(_)
176+
| GotoTarget::TypeParamParamSpecName(_)
177+
| GotoTarget::TypeParamTypeVarTupleName(_)
178+
| GotoTarget::NonLocal { .. }
179+
| GotoTarget::Globals { .. } => return None,
180+
};
181+
182+
Some(ty)
183+
}
184+
}
185+
176186
pub(crate) fn find_goto_target(parsed: &ParsedModule, offset: TextSize) -> Option<GotoTarget> {
177187
let token = parsed.tokens().at_offset(offset).find(|token| {
178188
matches!(
@@ -389,12 +399,12 @@ mod tests {
389399

390400
test.write_file("lib.py", "a = 10").unwrap();
391401

392-
assert_snapshot!(test.goto_type_definition(), @r###"
402+
assert_snapshot!(test.goto_type_definition(), @r"
393403
info: lint:goto-type-definition: Type definition
394404
--> /lib.py:1:1
395405
|
396406
1 | a = 10
397-
| ^
407+
| ^^^^^^
398408
|
399409
info: Source
400410
--> /main.py:4:13
@@ -404,7 +414,7 @@ mod tests {
404414
4 | lib
405415
| ^^^
406416
|
407-
"###);
417+
");
408418
}
409419

410420
#[test]
@@ -755,14 +765,13 @@ f(**kwargs<CURSOR>)
755765

756766
assert_snapshot!(test.goto_type_definition(), @r"
757767
info: lint:goto-type-definition: Type definition
758-
--> stdlib/builtins.pyi:443:7
768+
--> stdlib/types.pyi:677:11
759769
|
760-
441 | def __getitem__(self, key: int, /) -> str | int | None: ...
761-
442 |
762-
443 | class str(Sequence[str]):
763-
| ^^^
764-
444 | @overload
765-
445 | def __new__(cls, object: object = ...) -> Self: ...
770+
675 | if sys.version_info >= (3, 10):
771+
676 | @final
772+
677 | class NoneType:
773+
| ^^^^^^^^
774+
678 | def __bool__(self) -> Literal[False]: ...
766775
|
767776
info: Source
768777
--> /main.py:3:17
@@ -773,13 +782,14 @@ f(**kwargs<CURSOR>)
773782
|
774783
775784
info: lint:goto-type-definition: Type definition
776-
--> stdlib/types.pyi:677:11
785+
--> stdlib/builtins.pyi:443:7
777786
|
778-
675 | if sys.version_info >= (3, 10):
779-
676 | @final
780-
677 | class NoneType:
781-
| ^^^^^^^^
782-
678 | def __bool__(self) -> Literal[False]: ...
787+
441 | def __getitem__(self, key: int, /) -> str | int | None: ...
788+
442 |
789+
443 | class str(Sequence[str]):
790+
| ^^^
791+
444 | @overload
792+
445 | def __new__(cls, object: object = ...) -> Self: ...
783793
|
784794
info: Source
785795
--> /main.py:3:17

crates/red_knot_ide/src/lib.rs

Lines changed: 47 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,14 @@ mod db;
22
mod find_node;
33
mod goto;
44

5-
use std::ops::{Deref, DerefMut};
6-
75
pub use db::Db;
86
pub use goto::goto_type_definition;
9-
use red_knot_python_semantic::types::{
10-
Class, ClassBase, ClassLiteralType, FunctionType, InstanceType, IntersectionType,
11-
KnownInstanceType, ModuleLiteralType, Type,
12-
};
7+
use rustc_hash::FxHashSet;
8+
use std::ops::{Deref, DerefMut};
9+
10+
use red_knot_python_semantic::types::{Type, TypeDefinition};
1311
use ruff_db::files::{File, FileRange};
14-
use ruff_db::source::source_text;
15-
use ruff_text_size::{Ranged, TextLen, TextRange};
12+
use ruff_text_size::{Ranged, TextRange};
1613

1714
/// Information associated with a text range.
1815
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
@@ -54,7 +51,7 @@ where
5451
}
5552

5653
/// Target to which the editor can navigate to.
57-
#[derive(Debug, Clone)]
54+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
5855
pub struct NavigationTarget {
5956
file: File,
6057

@@ -95,6 +92,17 @@ impl NavigationTargets {
9592
Self(smallvec::SmallVec::new())
9693
}
9794

95+
fn unique(targets: impl IntoIterator<Item = NavigationTarget>) -> Self {
96+
let unique: FxHashSet<_> = targets.into_iter().collect();
97+
if unique.is_empty() {
98+
Self::empty()
99+
} else {
100+
let mut targets = unique.into_iter().collect::<Vec<_>>();
101+
targets.sort_by_key(|target| (target.file, target.focus_range.start()));
102+
Self(targets.into())
103+
}
104+
}
105+
98106
fn iter(&self) -> std::slice::Iter<'_, NavigationTarget> {
99107
self.0.iter()
100108
}
@@ -125,7 +133,7 @@ impl<'a> IntoIterator for &'a NavigationTargets {
125133

126134
impl FromIterator<NavigationTarget> for NavigationTargets {
127135
fn from_iter<T: IntoIterator<Item = NavigationTarget>>(iter: T) -> Self {
128-
Self(iter.into_iter().collect())
136+
Self::unique(iter)
129137
}
130138
}
131139

@@ -136,128 +144,46 @@ pub trait HasNavigationTargets {
136144
impl HasNavigationTargets for Type<'_> {
137145
fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets {
138146
match self {
139-
Type::BoundMethod(method) => method.function(db).navigation_targets(db),
140-
Type::FunctionLiteral(function) => function.navigation_targets(db),
141-
Type::ModuleLiteral(module) => module.navigation_targets(db),
142147
Type::Union(union) => union
143148
.iter(db.upcast())
144149
.flat_map(|target| target.navigation_targets(db))
145150
.collect(),
146-
Type::ClassLiteral(class) => class.navigation_targets(db),
147-
Type::Instance(instance) => instance.navigation_targets(db),
148-
Type::KnownInstance(instance) => instance.navigation_targets(db),
149-
Type::SubclassOf(subclass_of_type) => match subclass_of_type.subclass_of() {
150-
ClassBase::Class(class) => class.navigation_targets(db),
151-
ClassBase::Dynamic(_) => NavigationTargets::empty(),
152-
},
153-
154-
Type::StringLiteral(_)
155-
| Type::BooleanLiteral(_)
156-
| Type::LiteralString
157-
| Type::IntLiteral(_)
158-
| Type::BytesLiteral(_)
159-
| Type::SliceLiteral(_)
160-
| Type::MethodWrapper(_)
161-
| Type::WrapperDescriptor(_)
162-
| Type::PropertyInstance(_)
163-
| Type::Tuple(_) => self.to_meta_type(db.upcast()).navigation_targets(db),
164-
165-
Type::Intersection(intersection) => intersection.navigation_targets(db),
166-
167-
Type::Dynamic(_)
168-
| Type::Never
169-
| Type::Callable(_)
170-
| Type::AlwaysTruthy
171-
| Type::AlwaysFalsy => NavigationTargets::empty(),
172-
}
173-
}
174-
}
175-
176-
impl HasNavigationTargets for FunctionType<'_> {
177-
fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets {
178-
let function_range = self.focus_range(db.upcast());
179-
NavigationTargets::single(NavigationTarget {
180-
file: function_range.file(),
181-
focus_range: function_range.range(),
182-
full_range: self.full_range(db.upcast()).range(),
183-
})
184-
}
185-
}
186-
187-
impl HasNavigationTargets for Class<'_> {
188-
fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets {
189-
let class_range = self.focus_range(db.upcast());
190-
NavigationTargets::single(NavigationTarget {
191-
file: class_range.file(),
192-
focus_range: class_range.range(),
193-
full_range: self.full_range(db.upcast()).range(),
194-
})
195-
}
196-
}
197-
198-
impl HasNavigationTargets for ClassLiteralType<'_> {
199-
fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets {
200-
self.class().navigation_targets(db)
201-
}
202-
}
203-
204-
impl HasNavigationTargets for InstanceType<'_> {
205-
fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets {
206-
self.class().navigation_targets(db)
207-
}
208-
}
209-
210-
impl HasNavigationTargets for ModuleLiteralType<'_> {
211-
fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets {
212-
let file = self.module(db).file();
213-
let source = source_text(db.upcast(), file);
214151

215-
NavigationTargets::single(NavigationTarget {
216-
file,
217-
focus_range: TextRange::default(),
218-
full_range: TextRange::up_to(source.text_len()),
219-
})
220-
}
221-
}
222-
223-
impl HasNavigationTargets for KnownInstanceType<'_> {
224-
fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets {
225-
match self {
226-
KnownInstanceType::TypeVar(var) => {
227-
let definition = var.definition(db);
228-
let full_range = definition.full_range(db.upcast());
229-
230-
NavigationTargets::single(NavigationTarget {
231-
file: full_range.file(),
232-
focus_range: definition.focus_range(db.upcast()).range(),
233-
full_range: full_range.range(),
234-
})
152+
Type::Intersection(intersection) => {
153+
// Only consider the positive elements because the negative elements are mainly from narrowing constraints.
154+
let mut targets = intersection
155+
.iter_positive(db.upcast())
156+
.filter(|ty| !ty.is_unknown());
157+
158+
let Some(first) = targets.next() else {
159+
return NavigationTargets::empty();
160+
};
161+
162+
match targets.next() {
163+
Some(_) => {
164+
// If there are multiple types in the intersection, we can't navigate to a single one
165+
// because the type is the intersection of all those types.
166+
NavigationTargets::empty()
167+
}
168+
None => first.navigation_targets(db),
169+
}
235170
}
236171

237-
// TODO: Track the definition of `KnownInstance` and navigate to their definition.
238-
_ => NavigationTargets::empty(),
172+
ty => ty
173+
.definition(db.upcast())
174+
.map(|definition| definition.navigation_targets(db))
175+
.unwrap_or_else(NavigationTargets::empty),
239176
}
240177
}
241178
}
242179

243-
impl HasNavigationTargets for IntersectionType<'_> {
180+
impl HasNavigationTargets for TypeDefinition<'_> {
244181
fn navigation_targets(&self, db: &dyn Db) -> NavigationTargets {
245-
// Only consider the positive elements because the negative elements are mainly from narrowing constraints.
246-
let mut targets = self
247-
.iter_positive(db.upcast())
248-
.filter(|ty| !ty.is_unknown());
249-
250-
let Some(first) = targets.next() else {
251-
return NavigationTargets::empty();
252-
};
253-
254-
match targets.next() {
255-
Some(_) => {
256-
// If there are multiple types in the intersection, we can't navigate to a single one
257-
// because the type is the intersection of all those types.
258-
NavigationTargets::empty()
259-
}
260-
None => first.navigation_targets(db),
261-
}
182+
let full_range = self.full_range(db.upcast());
183+
NavigationTargets::single(NavigationTarget {
184+
file: full_range.file(),
185+
focus_range: self.focus_range(db.upcast()).unwrap_or(full_range).range(),
186+
full_range: full_range.range(),
187+
})
262188
}
263189
}

0 commit comments

Comments
 (0)