diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index 800287478997e..7ca29f5f2d9c5 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -9,10 +9,13 @@ use ruff_python_ast::{ModModule, StringLiteral}; use crate::cache::KeyValueCache; use crate::db::{LintDb, LintJar, QueryResult}; use crate::files::FileId; +use crate::module::ModuleName; use crate::parse::{parse, Parsed}; use crate::source::{source_text, Source}; -use crate::symbols::{symbol_table, Definition, SymbolId, SymbolTable}; -use crate::types::{infer_symbol_type, Type}; +use crate::symbols::{ + resolve_global_symbol, symbol_table, Definition, GlobalSymbolId, SymbolId, SymbolTable, +}; +use crate::types::{infer_definition_type, infer_symbol_type, Type}; #[tracing::instrument(level = "debug", skip(db))] pub(crate) fn lint_syntax(db: &dyn LintDb, file_id: FileId) -> QueryResult { @@ -90,6 +93,7 @@ pub(crate) fn lint_semantic(db: &dyn LintDb, file_id: FileId) -> QueryResult QueryResult<()> { Ok(()) } +fn lint_bad_overrides(context: &SemanticLintContext) -> QueryResult<()> { + // TODO we should have a special marker on the real typing module (from typeshed) so if you + // have your own "typing" module in your project, we don't consider it THE typing module (and + // same for other stdlib modules that our lint rules care about) + let Some(typing_override) = + resolve_global_symbol(context.db.upcast(), ModuleName::new("typing"), "override")? + else { + // TODO once we bundle typeshed, this should be unreachable!() + return Ok(()); + }; + + // TODO we should maybe index definitions by type instead of iterating all, or else iterate all + // just once, match, and branch to all lint rules that care about a type of definition + for (symbol, definition) in context.symbols().all_definitions() { + if !matches!(definition, Definition::FunctionDef(_)) { + continue; + } + let ty = infer_definition_type( + context.db.upcast(), + GlobalSymbolId { + file_id: context.file_id, + symbol_id: symbol, + }, + definition.clone(), + )?; + let Type::Function(func) = ty else { + unreachable!("type of a FunctionDef should always be a Function"); + }; + let Some(class) = func.get_containing_class(context.db.upcast())? else { + // not a method of a class + continue; + }; + if func.has_decorator(context.db.upcast(), typing_override)? { + let method_name = func.name(context.db.upcast())?; + if class + .get_super_class_member(context.db.upcast(), &method_name)? + .is_none() + { + // TODO should have a qualname() method to support nested classes + context.push_diagnostic( + format!( + "Method {}.{} is decorated with `typing.override` but does not override any base class method", + class.name(context.db.upcast())?, + method_name, + )); + } + } + } + Ok(()) +} + pub struct SemanticLintContext<'a> { file_id: FileId, source: Source, @@ -163,7 +218,13 @@ impl<'a> SemanticLintContext<'a> { } pub fn infer_symbol_type(&self, symbol_id: SymbolId) -> QueryResult { - infer_symbol_type(self.db.upcast(), self.file_id, symbol_id) + infer_symbol_type( + self.db.upcast(), + GlobalSymbolId { + file_id: self.file_id, + symbol_id, + }, + ) } pub fn push_diagnostic(&self, diagnostic: String) { diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index 5f00f0d45db94..475c52e3516dd 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -14,15 +14,14 @@ use ruff_index::{newtype_index, IndexVec}; use ruff_python_ast as ast; use ruff_python_ast::visitor::preorder::PreorderVisitor; -use crate::ast_ids::TypedNodeKey; +use crate::ast_ids::{NodeKey, TypedNodeKey}; use crate::cache::KeyValueCache; use crate::db::{QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; -use crate::module::ModuleName; +use crate::module::{resolve_module, ModuleName}; use crate::parse::parse; use crate::Name; -#[allow(unreachable_pub)] #[tracing::instrument(level = "debug", skip(db))] pub fn symbol_table(db: &dyn SemanticDb, file_id: FileId) -> QueryResult> { let jar: &SemanticJar = db.jar()?; @@ -33,6 +32,32 @@ pub fn symbol_table(db: &dyn SemanticDb, file_id: FileId) -> QueryResult QueryResult> { + let Some(typing_module) = resolve_module(db, module)? else { + return Ok(None); + }; + let typing_file = typing_module.path(db)?.file(); + let typing_table = symbol_table(db, typing_file)?; + let Some(typing_override) = typing_table.root_symbol_id_by_name(name) else { + return Ok(None); + }; + Ok(Some(GlobalSymbolId { + file_id: typing_file, + symbol_id: typing_override, + })) +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct GlobalSymbolId { + pub(crate) file_id: FileId, + pub(crate) symbol_id: SymbolId, +} + type Map = hashbrown::HashMap; #[newtype_index] @@ -65,7 +90,12 @@ pub(crate) enum ScopeKind { pub(crate) struct Scope { name: Name, kind: ScopeKind, - child_scopes: Vec, + parent: Option, + children: Vec, + /// the definition (e.g. class or function) that created this scope + definition: Option, + /// the symbol (e.g. class or function) that owns this scope + defining_symbol: Option, /// symbol IDs, hashed by symbol name symbols_by_name: Map, } @@ -78,6 +108,14 @@ impl Scope { pub(crate) fn kind(&self) -> ScopeKind { self.kind } + + pub(crate) fn definition(&self) -> Option { + self.definition.clone() + } + + pub(crate) fn defining_symbol(&self) -> Option { + self.defining_symbol + } } #[derive(Debug)] @@ -114,6 +152,10 @@ impl Symbol { self.name.as_str() } + pub(crate) fn scope_id(&self) -> ScopeId { + self.scope_id + } + /// Is the symbol used in its containing scope? pub(crate) fn is_used(&self) -> bool { self.flags.contains(SymbolFlags::IS_USED) @@ -132,6 +174,7 @@ impl Symbol { // TODO storing TypedNodeKey for definitions means we have to search to find them again in the AST; // this is at best O(log n). If looking up definitions is a bottleneck we should look for // alternatives here. +// TODO intern Definitions in SymbolTable and reference using IDs? #[derive(Clone, Debug)] pub(crate) enum Definition { // For the import cases, we don't need reference to any arbitrary AST subtrees (annotations, @@ -140,7 +183,7 @@ pub(crate) enum Definition { // the small amount of information we need from the AST. Import(ImportDefinition), ImportFrom(ImportFromDefinition), - ClassDef(ClassDefinition), + ClassDef(TypedNodeKey), FunctionDef(TypedNodeKey), Assignment(TypedNodeKey), AnnotatedAssignment(TypedNodeKey), @@ -173,12 +216,6 @@ impl ImportFromDefinition { } } -#[derive(Clone, Debug)] -pub(crate) struct ClassDefinition { - pub(crate) node_key: TypedNodeKey, - pub(crate) scope_id: ScopeId, -} - #[derive(Debug, Clone)] pub enum Dependency { Module(ModuleName), @@ -193,7 +230,11 @@ pub enum Dependency { pub struct SymbolTable { scopes_by_id: IndexVec, symbols_by_id: IndexVec, + /// the definitions for each symbol defs: FxHashMap>, + /// map of AST node (e.g. class/function def) to sub-scope it creates + scopes_by_node: FxHashMap, + /// dependencies of this module dependencies: Vec, } @@ -214,12 +255,16 @@ impl SymbolTable { scopes_by_id: IndexVec::new(), symbols_by_id: IndexVec::new(), defs: FxHashMap::default(), + scopes_by_node: FxHashMap::default(), dependencies: Vec::new(), }; table.scopes_by_id.push(Scope { name: Name::new(""), kind: ScopeKind::Module, - child_scopes: Vec::new(), + parent: None, + children: Vec::new(), + definition: None, + defining_symbol: None, symbols_by_name: Map::default(), }); table @@ -260,7 +305,7 @@ impl SymbolTable { } pub(crate) fn child_scope_ids_of(&self, scope_id: ScopeId) -> &[ScopeId] { - &self.scopes_by_id[scope_id].child_scopes + &self.scopes_by_id[scope_id].children } pub(crate) fn child_scopes_of(&self, scope_id: ScopeId) -> ScopeIterator<&[ScopeId]> { @@ -303,6 +348,32 @@ impl SymbolTable { self.symbol_by_name(SymbolTable::root_scope_id(), name) } + pub(crate) fn scope_id_of_symbol(&self, symbol_id: SymbolId) -> ScopeId { + self.symbols_by_id[symbol_id].scope_id + } + + pub(crate) fn scope_of_symbol(&self, symbol_id: SymbolId) -> &Scope { + &self.scopes_by_id[self.scope_id_of_symbol(symbol_id)] + } + + pub(crate) fn parent_scopes( + &self, + scope_id: ScopeId, + ) -> ScopeIterator + '_> { + ScopeIterator { + table: self, + ids: std::iter::successors(Some(scope_id), |scope| self.scopes_by_id[*scope].parent), + } + } + + pub(crate) fn parent_scope(&self, scope_id: ScopeId) -> Option { + self.scopes_by_id[scope_id].parent + } + + pub(crate) fn scope_id_for_node(&self, node_key: &NodeKey) -> ScopeId { + self.scopes_by_node[node_key] + } + pub(crate) fn definitions(&self, symbol_id: SymbolId) -> &[Definition] { self.defs .get(&symbol_id) @@ -316,7 +387,7 @@ impl SymbolTable { .flat_map(|(sym_id, defs)| defs.iter().map(move |def| (*sym_id, def))) } - fn add_or_update_symbol( + pub(crate) fn add_or_update_symbol( &mut self, scope_id: ScopeId, name: &str, @@ -357,15 +428,20 @@ impl SymbolTable { parent_scope_id: ScopeId, name: &str, kind: ScopeKind, + definition: Option, + defining_symbol: Option, ) -> ScopeId { let new_scope_id = self.scopes_by_id.push(Scope { name: Name::new(name), kind, - child_scopes: Vec::new(), + parent: Some(parent_scope_id), + children: Vec::new(), + definition, + defining_symbol, symbols_by_name: Map::default(), }); let parent_scope = &mut self.scopes_by_id[parent_scope_id]; - parent_scope.child_scopes.push(new_scope_id); + parent_scope.children.push(new_scope_id); new_scope_id } @@ -412,20 +488,22 @@ where } } +// TODO maybe get rid of this and just do all data access via methods on ScopeId? pub(crate) struct ScopeIterator<'a, I> { table: &'a SymbolTable, ids: I, } +/// iterate (`ScopeId`, `Scope`) pairs for given `ScopeId` iterator impl<'a, I> Iterator for ScopeIterator<'a, I> where I: Iterator, { - type Item = &'a Scope; + type Item = (ScopeId, &'a Scope); fn next(&mut self) -> Option { let id = self.ids.next()?; - Some(&self.table.scopes_by_id[id]) + Some((id, &self.table.scopes_by_id[id])) } fn size_hint(&self) -> (usize, Option) { @@ -441,7 +519,7 @@ where { fn next_back(&mut self) -> Option { let id = self.ids.next_back()?; - Some(&self.table.scopes_by_id[id]) + Some((id, &self.table.scopes_by_id[id])) } } @@ -472,8 +550,16 @@ impl SymbolTableBuilder { symbol_id } - fn push_scope(&mut self, name: &str, kind: ScopeKind) -> ScopeId { - let scope_id = self.table.add_child_scope(self.cur_scope(), name, kind); + fn push_scope( + &mut self, + name: &str, + kind: ScopeKind, + definition: Option, + defining_symbol: Option, + ) -> ScopeId { + let scope_id = + self.table + .add_child_scope(self.cur_scope(), name, kind, definition, defining_symbol); self.scopes.push(scope_id); scope_id } @@ -491,14 +577,20 @@ impl SymbolTableBuilder { .expect("Scope stack should never be empty") } + fn record_scope_for_node(&mut self, node_key: NodeKey, scope_id: ScopeId) { + self.table.scopes_by_node.insert(node_key, scope_id); + } + fn with_type_params( &mut self, name: &str, params: &Option>, + definition: Option, + defining_symbol: Option, nested: impl FnOnce(&mut Self) -> ScopeId, ) -> ScopeId { if let Some(type_params) = params { - self.push_scope(name, ScopeKind::Annotation); + self.push_scope(name, ScopeKind::Annotation, definition, defining_symbol); for type_param in &type_params.type_params { let name = match type_param { ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name, @@ -539,27 +631,50 @@ impl PreorderVisitor<'_> for SymbolTableBuilder { // TODO need to capture more definition statements here match stmt { ast::Stmt::ClassDef(node) => { - let scope_id = self.with_type_params(&node.name, &node.type_params, |builder| { - let scope_id = builder.push_scope(&node.name, ScopeKind::Class); - ast::visitor::preorder::walk_stmt(builder, stmt); - builder.pop_scope(); - scope_id - }); - let def = Definition::ClassDef(ClassDefinition { - node_key: TypedNodeKey::from_node(node), - scope_id, - }); - self.add_or_update_symbol_with_def(&node.name, def); + let node_key = TypedNodeKey::from_node(node); + let def = Definition::ClassDef(node_key.clone()); + let symbol_id = self.add_or_update_symbol_with_def(&node.name, def.clone()); + let scope_id = self.with_type_params( + &node.name, + &node.type_params, + Some(def.clone()), + Some(symbol_id), + |builder| { + let scope_id = builder.push_scope( + &node.name, + ScopeKind::Class, + Some(def.clone()), + Some(symbol_id), + ); + ast::visitor::preorder::walk_stmt(builder, stmt); + builder.pop_scope(); + scope_id + }, + ); + self.record_scope_for_node(*node_key.erased(), scope_id); } ast::Stmt::FunctionDef(node) => { - let def = Definition::FunctionDef(TypedNodeKey::from_node(node)); - self.add_or_update_symbol_with_def(&node.name, def); - self.with_type_params(&node.name, &node.type_params, |builder| { - let scope_id = builder.push_scope(&node.name, ScopeKind::Function); - ast::visitor::preorder::walk_stmt(builder, stmt); - builder.pop_scope(); - scope_id - }); + let node_key = TypedNodeKey::from_node(node); + let def = Definition::FunctionDef(node_key.clone()); + let symbol_id = self.add_or_update_symbol_with_def(&node.name, def.clone()); + let scope_id = self.with_type_params( + &node.name, + &node.type_params, + Some(def.clone()), + Some(symbol_id), + |builder| { + let scope_id = builder.push_scope( + &node.name, + ScopeKind::Function, + Some(def.clone()), + Some(symbol_id), + ); + ast::visitor::preorder::walk_stmt(builder, stmt); + builder.pop_scope(); + scope_id + }, + ); + self.record_scope_for_node(*node_key.erased(), scope_id); } ast::Stmt::Import(ast::StmtImport { names, .. }) => { for alias in names { @@ -933,7 +1048,7 @@ mod tests { let mut table = SymbolTable::new(); let root_scope_id = SymbolTable::root_scope_id(); let foo_symbol_top = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty()); - let c_scope = table.add_child_scope(root_scope_id, "C", ScopeKind::Class); + let c_scope = table.add_child_scope(root_scope_id, "C", ScopeKind::Class, None, None); let foo_symbol_inner = table.add_or_update_symbol(c_scope, "foo", SymbolFlags::empty()); assert_ne!(foo_symbol_top, foo_symbol_inner); } diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index c10583757b959..cb17803521d74 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -1,18 +1,16 @@ #![allow(dead_code)] - -use rustc_hash::FxHashMap; - -pub(crate) use infer::infer_symbol_type; -use ruff_index::{newtype_index, IndexVec}; - use crate::ast_ids::NodeKey; use crate::db::{QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; -use crate::symbols::{symbol_table, ScopeId, SymbolId}; +use crate::symbols::{symbol_table, GlobalSymbolId, ScopeId, ScopeKind, SymbolId}; use crate::{FxDashMap, FxIndexSet, Name}; +use ruff_index::{newtype_index, IndexVec}; +use rustc_hash::FxHashMap; pub(crate) mod infer; +pub(crate) use infer::{infer_definition_type, infer_symbol_type}; + /// unique ID for a type #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub enum Type { @@ -82,10 +80,10 @@ impl TypeStore { self.modules.remove(&file_id); } - pub fn cache_symbol_type(&self, file_id: FileId, symbol_id: SymbolId, ty: Type) { - self.add_or_get_module(file_id) + pub fn cache_symbol_type(&self, symbol: GlobalSymbolId, ty: Type) { + self.add_or_get_module(symbol.file_id) .symbol_types - .insert(symbol_id, ty); + .insert(symbol.symbol_id, ty); } pub fn cache_node_type(&self, file_id: FileId, node_key: NodeKey, ty: Type) { @@ -94,10 +92,10 @@ impl TypeStore { .insert(node_key, ty); } - pub fn get_cached_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Option { - self.try_get_module(file_id)? + pub fn get_cached_symbol_type(&self, symbol: GlobalSymbolId) -> Option { + self.try_get_module(symbol.file_id)? .symbol_types - .get(&symbol_id) + .get(&symbol.symbol_id) .copied() } @@ -122,9 +120,16 @@ impl TypeStore { self.modules.get(&file_id) } - fn add_function(&self, file_id: FileId, name: &str, decorators: Vec) -> FunctionTypeId { + fn add_function( + &self, + file_id: FileId, + name: &str, + symbol_id: SymbolId, + scope_id: ScopeId, + decorators: Vec, + ) -> FunctionTypeId { self.add_or_get_module(file_id) - .add_function(name, decorators) + .add_function(name, symbol_id, scope_id, decorators) } fn add_class( @@ -257,6 +262,80 @@ pub struct FunctionTypeId { func_id: ModuleFunctionTypeId, } +impl FunctionTypeId { + fn function(self, db: &dyn SemanticDb) -> QueryResult { + let jar: &SemanticJar = db.jar()?; + Ok(jar.type_store.get_function(self)) + } + + pub(crate) fn name(self, db: &dyn SemanticDb) -> QueryResult { + Ok(self.function(db)?.name().into()) + } + + pub(crate) fn global_symbol(self, db: &dyn SemanticDb) -> QueryResult { + Ok(GlobalSymbolId { + file_id: self.file(), + symbol_id: self.symbol(db)?, + }) + } + + pub(crate) fn file(self) -> FileId { + self.file_id + } + + pub(crate) fn symbol(self, db: &dyn SemanticDb) -> QueryResult { + let FunctionType { symbol_id, .. } = *self.function(db)?; + Ok(symbol_id) + } + + pub(crate) fn get_containing_class( + self, + db: &dyn SemanticDb, + ) -> QueryResult> { + let table = symbol_table(db, self.file_id)?; + let FunctionType { symbol_id, .. } = *self.function(db)?; + let scope_id = symbol_id.symbol(&table).scope_id(); + let scope = scope_id.scope(&table); + if !matches!(scope.kind(), ScopeKind::Class) { + return Ok(None); + }; + let Some(def) = scope.definition() else { + return Ok(None); + }; + let Some(symbol_id) = scope.defining_symbol() else { + return Ok(None); + }; + let Type::Class(class) = infer_definition_type( + db, + GlobalSymbolId { + file_id: self.file_id, + symbol_id, + }, + def, + )? + else { + return Ok(None); + }; + Ok(Some(class)) + } + + pub(crate) fn has_decorator( + self, + db: &dyn SemanticDb, + decorator_symbol: GlobalSymbolId, + ) -> QueryResult { + for deco_ty in self.function(db)?.decorators() { + let Type::Function(deco_func) = deco_ty else { + continue; + }; + if deco_func.global_symbol(db)? == decorator_symbol { + return Ok(true); + } + } + Ok(false) + } +} + #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] pub struct ClassTypeId { file_id: FileId, @@ -264,14 +343,47 @@ pub struct ClassTypeId { } impl ClassTypeId { - fn get_own_class_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { + fn class(self, db: &dyn SemanticDb) -> QueryResult { let jar: &SemanticJar = db.jar()?; + Ok(jar.type_store.get_class(self)) + } + + pub(crate) fn name(self, db: &dyn SemanticDb) -> QueryResult { + Ok(self.class(db)?.name().into()) + } + + pub(crate) fn get_super_class_member( + self, + db: &dyn SemanticDb, + name: &Name, + ) -> QueryResult> { + // TODO we should linearize the MRO instead of doing this recursively + let class = self.class(db)?; + for base in class.bases() { + if let Type::Class(base) = base { + if let Some(own_member) = base.get_own_class_member(db, name)? { + return Ok(Some(own_member)); + } + if let Some(base_member) = base.get_super_class_member(db, name)? { + return Ok(Some(base_member)); + } + } + } + Ok(None) + } + fn get_own_class_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { // TODO: this should distinguish instance-only members (e.g. `x: int`) and not return them - let ClassType { scope_id, .. } = *jar.type_store.get_class(self); + let ClassType { scope_id, .. } = *self.class(db)?; let table = symbol_table(db, self.file_id)?; if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) { - Ok(Some(infer_symbol_type(db, self.file_id, symbol_id)?)) + Ok(Some(infer_symbol_type( + db, + GlobalSymbolId { + file_id: self.file_id, + symbol_id, + }, + )?)) } else { Ok(None) } @@ -334,9 +446,17 @@ impl ModuleTypeStore { } } - fn add_function(&mut self, name: &str, decorators: Vec) -> FunctionTypeId { + fn add_function( + &mut self, + name: &str, + symbol_id: SymbolId, + scope_id: ScopeId, + decorators: Vec, + ) -> FunctionTypeId { let func_id = self.functions.push(FunctionType { name: Name::new(name), + symbol_id, + scope_id, decorators, }); FunctionTypeId { @@ -436,7 +556,7 @@ pub(crate) struct ClassType { /// Name of the class at definition name: Name, /// `ScopeId` of the class body - pub(crate) scope_id: ScopeId, + scope_id: ScopeId, /// Types of all class bases bases: Vec, } @@ -453,7 +573,13 @@ impl ClassType { #[derive(Debug)] pub(crate) struct FunctionType { + /// name of the function at definition name: Name, + /// symbol which this function is a definition of + symbol_id: SymbolId, + /// scope of this function's body + scope_id: ScopeId, + /// types of all decorators on this function decorators: Vec, } @@ -462,7 +588,11 @@ impl FunctionType { self.name.as_str() } - fn decorators(&self) -> &[Type] { + fn scope_id(&self) -> ScopeId { + self.scope_id + } + + pub(crate) fn decorators(&self) -> &[Type] { self.decorators.as_slice() } } @@ -493,12 +623,12 @@ impl UnionType { // directly in intersections rather than as a separate type. This sacrifices some efficiency in the // case where a Not appears outside an intersection (unclear when that could even happen, but we'd // have to represent it as a single-element intersection if it did) in exchange for better -// efficiency in the not-within-intersection case. +// efficiency in the within-intersection case. #[derive(Debug)] pub(crate) struct IntersectionType { // the intersection type includes only values in all of these types positive: FxIndexSet, - // negated elements of the intersection, e.g. + // the intersection type does not include any value in any of these types negative: FxIndexSet, } @@ -530,7 +660,7 @@ mod tests { use std::path::Path; use crate::files::Files; - use crate::symbols::SymbolTable; + use crate::symbols::{SymbolFlags, SymbolTable}; use crate::types::{Type, TypeStore}; use crate::FxIndexSet; @@ -550,7 +680,20 @@ mod tests { let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); - let id = store.add_function(file_id, "func", vec![Type::Unknown]); + let mut table = SymbolTable::new(); + let func_symbol = table.add_or_update_symbol( + SymbolTable::root_scope_id(), + "func", + SymbolFlags::IS_DEFINED, + ); + + let id = store.add_function( + file_id, + "func", + func_symbol, + SymbolTable::root_scope_id(), + vec![Type::Unknown], + ); assert_eq!(store.get_function(id).name(), "func"); assert_eq!(store.get_function(id).decorators(), vec![Type::Unknown]); let func = Type::Function(id); diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index 43ab0b131f70a..a34f367fb55a2 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -5,33 +5,47 @@ use ruff_python_ast::AstNode; use crate::db::{QueryResult, SemanticDb, SemanticJar}; -use crate::module::{resolve_module, ModuleName}; +use crate::module::ModuleName; use crate::parse::parse; -use crate::symbols::{symbol_table, ClassDefinition, Definition, ImportFromDefinition, SymbolId}; +use crate::symbols::{ + resolve_global_symbol, symbol_table, Definition, GlobalSymbolId, ImportFromDefinition, +}; use crate::types::Type; use crate::FileId; // FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`. #[tracing::instrument(level = "trace", skip(db))] -pub fn infer_symbol_type( - db: &dyn SemanticDb, - file_id: FileId, - symbol_id: SymbolId, -) -> QueryResult { - let symbols = symbol_table(db, file_id)?; - let defs = symbols.definitions(symbol_id); - +pub fn infer_symbol_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult { + let symbols = symbol_table(db, symbol.file_id)?; + let defs = symbols.definitions(symbol.symbol_id); let jar: &SemanticJar = db.jar()?; - let type_store = &jar.type_store; - if let Some(ty) = type_store.get_cached_symbol_type(file_id, symbol_id) { + if let Some(ty) = jar.type_store.get_cached_symbol_type(symbol) { return Ok(ty); } // TODO handle multiple defs, conditional defs... assert_eq!(defs.len(), 1); - let ty = match &defs[0] { + let ty = infer_definition_type(db, symbol, defs[0].clone())?; + + jar.type_store.cache_symbol_type(symbol, ty); + + // TODO record dependencies + Ok(ty) +} + +#[tracing::instrument(level = "trace", skip(db))] +pub fn infer_definition_type( + db: &dyn SemanticDb, + symbol: GlobalSymbolId, + definition: Definition, +) -> QueryResult { + let jar: &SemanticJar = db.jar()?; + let type_store = &jar.type_store; + let file_id = symbol.file_id; + + match definition { Definition::ImportFrom(ImportFromDefinition { module, name, @@ -40,24 +54,19 @@ pub fn infer_symbol_type( // TODO relative imports assert!(matches!(level, 0)); let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports")); - if let Some(module) = resolve_module(db, module_name)? { - let remote_file_id = module.path(db)?.file(); - let remote_symbols = symbol_table(db, remote_file_id)?; - if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) { - infer_symbol_type(db, remote_file_id, remote_symbol_id)? - } else { - Type::Unknown - } + if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? { + infer_symbol_type(db, remote_symbol) } else { - Type::Unknown + Ok(Type::Unknown) } } - Definition::ClassDef(ClassDefinition { node_key, scope_id }) => { + Definition::ClassDef(node_key) => { if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) { - ty + Ok(ty) } else { let parsed = parse(db.upcast(), file_id)?; let ast = parsed.ast(); + let table = symbol_table(db, file_id)?; let node = node_key.resolve_unwrap(ast.as_any_node_ref()); let mut bases = Vec::with_capacity(node.bases().len()); @@ -65,19 +74,19 @@ pub fn infer_symbol_type( for base in node.bases() { bases.push(infer_expr_type(db, file_id, base)?); } - - let ty = - Type::Class(type_store.add_class(file_id, &node.name.id, *scope_id, bases)); + let scope_id = table.scope_id_for_node(node_key.erased()); + let ty = Type::Class(type_store.add_class(file_id, &node.name.id, scope_id, bases)); type_store.cache_node_type(file_id, *node_key.erased(), ty); - ty + Ok(ty) } } Definition::FunctionDef(node_key) => { if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) { - ty + Ok(ty) } else { let parsed = parse(db.upcast(), file_id)?; let ast = parsed.ast(); + let table = symbol_table(db, file_id)?; let node = node_key .resolve(ast.as_any_node_ref()) .expect("node key should resolve"); @@ -87,12 +96,18 @@ pub fn infer_symbol_type( .iter() .map(|decorator| infer_expr_type(db, file_id, &decorator.expression)) .collect::>()?; - + let scope_id = table.scope_id_for_node(node_key.erased()); let ty = type_store - .add_function(file_id, &node.name.id, decorator_tys) + .add_function( + file_id, + &node.name.id, + symbol.symbol_id, + scope_id, + decorator_tys, + ) .into(); type_store.cache_node_type(file_id, *node_key.erased(), ty); - ty + Ok(ty) } } Definition::Assignment(node_key) => { @@ -100,15 +115,10 @@ pub fn infer_symbol_type( let ast = parsed.ast(); let node = node_key.resolve_unwrap(ast.as_any_node_ref()); // TODO handle unpacking assignment correctly - infer_expr_type(db, file_id, &node.value)? + infer_expr_type(db, file_id, &node.value) } _ => todo!("other kinds of definitions"), - }; - - type_store.cache_symbol_type(file_id, symbol_id, ty); - - // TODO record dependencies - Ok(ty) + } } fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> QueryResult { @@ -116,8 +126,9 @@ fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> Qu let symbols = symbol_table(db, file_id)?; match expr { ast::Expr::Name(name) => { + // TODO look up in the correct scope, don't assume global if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) { - infer_symbol_type(db, file_id, symbol_id) + infer_symbol_type(db, GlobalSymbolId { file_id, symbol_id }) } else { Ok(Type::Unknown) } @@ -133,7 +144,7 @@ mod tests { use crate::module::{ resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind, }; - use crate::symbols::symbol_table; + use crate::symbols::{symbol_table, GlobalSymbolId}; use crate::types::{infer_symbol_type, Type}; use crate::Name; @@ -180,7 +191,13 @@ mod tests { .root_symbol_id_by_name("E") .expect("E symbol should be found"); - let ty = infer_symbol_type(db, a_file, e_sym)?; + let ty = infer_symbol_type( + db, + GlobalSymbolId { + file_id: a_file, + symbol_id: e_sym, + }, + )?; let jar = HasJar::::jar(db)?; assert!(matches!(ty, Type::Class(_))); @@ -205,7 +222,13 @@ mod tests { .root_symbol_id_by_name("Sub") .expect("Sub symbol should be found"); - let ty = infer_symbol_type(db, file, sym)?; + let ty = infer_symbol_type( + db, + GlobalSymbolId { + file_id: file, + symbol_id: sym, + }, + )?; let Type::Class(class_id) = ty else { panic!("Sub is not a Class") @@ -240,7 +263,13 @@ mod tests { .root_symbol_id_by_name("C") .expect("C symbol should be found"); - let ty = infer_symbol_type(db, file, sym)?; + let ty = infer_symbol_type( + db, + GlobalSymbolId { + file_id: file, + symbol_id: sym, + }, + )?; let Type::Class(class_id) = ty else { panic!("C is not a Class");