Skip to content

Commit

Permalink
Use Cow<'db, Name> in symbol table`
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaReiser committed Jun 29, 2024
1 parent 1695155 commit df99bef
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 41 deletions.
12 changes: 6 additions & 6 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type SymbolMap = hashbrown::HashMap<ScopedSymbolId, (), ()>;
///
/// Prefer using [`symbol_table`] when working with symbols from a single scope.
#[salsa::tracked(return_ref, no_eq)]
pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex {
pub(crate) fn semantic_index<'db>(db: &'db dyn Db, file: VfsFile) -> SemanticIndex<'db> {
let _ = tracing::trace_span!("semantic_index", file = ?file.debug(db.upcast())).enter();

let parsed = parsed_module(db.upcast(), file);
Expand All @@ -42,7 +42,7 @@ pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex {
/// Salsa can avoid invalidating dependent queries if this scope's symbol table
/// is unchanged.
#[salsa::tracked]
pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<SymbolTable> {
pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<SymbolTable<'db>> {
let _ = tracing::trace_span!("symbol_table", scope = ?scope.debug(db)).enter();
let index = semantic_index(db, scope.file(db));

Expand Down Expand Up @@ -72,9 +72,9 @@ pub fn public_symbol<'db>(

/// The symbol tables for an entire file.
#[derive(Debug)]
pub struct SemanticIndex {
pub struct SemanticIndex<'db> {
/// List of all symbol tables in this file, indexed by scope.
symbol_tables: IndexVec<FileScopeId, Arc<SymbolTable>>,
symbol_tables: IndexVec<FileScopeId, Arc<SymbolTable<'db>>>,

/// List of all scopes in this file.
scopes: IndexVec<FileScopeId, Scope>,
Expand All @@ -94,7 +94,7 @@ pub struct SemanticIndex {
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>,
}

impl SemanticIndex {
impl<'db> SemanticIndex<'db> {
/// Returns the symbol table for a specific scope.
///
/// Use the Salsa cached [`symbol_table`] query if you only need the
Expand Down Expand Up @@ -297,7 +297,7 @@ mod tests {
TestCase { db, file }
}

fn names(table: &SymbolTable) -> Vec<&str> {
fn names<'db>(table: &'db SymbolTable<'db>) -> Vec<&'db str> {
table
.symbols()
.map(|symbol| symbol.name().as_str())
Expand Down
35 changes: 18 additions & 17 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::sync::Arc;

use rustc_hash::FxHashMap;
Expand Down Expand Up @@ -28,7 +29,7 @@ pub(super) struct SemanticIndexBuilder<'a> {

// Semantic Index fields
scopes: IndexVec<FileScopeId, Scope>,
symbol_tables: IndexVec<FileScopeId, SymbolTableBuilder>,
symbol_tables: IndexVec<FileScopeId, SymbolTableBuilder<'a>>,
ast_ids: IndexVec<FileScopeId, AstIdsBuilder>,
expression_scopes: FxHashMap<NodeKey, FileScopeId>,
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>,
Expand Down Expand Up @@ -114,7 +115,7 @@ impl<'a> SemanticIndexBuilder<'a> {
id
}

fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder {
fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder<'a> {
let scope_id = self.current_scope();
&mut self.symbol_tables[scope_id]
}
Expand All @@ -124,15 +125,15 @@ impl<'a> SemanticIndexBuilder<'a> {
&mut self.ast_ids[scope_id]
}

fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId {
fn add_or_update_symbol(&mut self, name: Cow<'a, Name>, flags: SymbolFlags) -> ScopedSymbolId {
let symbol_table = self.current_symbol_table();

symbol_table.add_or_update_symbol(name, flags, None)
}

fn add_or_update_symbol_with_definition(
&mut self,
name: Name,
name: Cow<'a, Name>,
definition: Definition,
) -> ScopedSymbolId {
let symbol_table = self.current_symbol_table();
Expand All @@ -143,7 +144,7 @@ impl<'a> SemanticIndexBuilder<'a> {
fn with_type_params(
&mut self,
name: &Name,
with_params: &WithTypeParams,
with_params: &WithTypeParams<'a>,
defining_symbol: FileSymbolId,
nested: impl FnOnce(&mut Self) -> FileScopeId,
) -> FileScopeId {
Expand All @@ -167,7 +168,7 @@ impl<'a> SemanticIndexBuilder<'a> {
ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => name,
ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => name,
};
self.add_or_update_symbol(name.id.clone(), SymbolFlags::IS_DEFINED);
self.add_or_update_symbol(Cow::Borrowed(&name.id), SymbolFlags::IS_DEFINED);
}
}

Expand All @@ -180,7 +181,7 @@ impl<'a> SemanticIndexBuilder<'a> {
nested_scope
}

pub(super) fn build(mut self) -> SemanticIndex {
pub(super) fn build(mut self) -> SemanticIndex<'a> {
let module = self.module;
self.visit_body(module.suite());

Expand Down Expand Up @@ -218,8 +219,8 @@ impl<'a> SemanticIndexBuilder<'a> {
}
}

impl Visitor<'_> for SemanticIndexBuilder<'_> {
fn visit_stmt(&mut self, stmt: &ast::Stmt) {
impl<'db> Visitor<'db> for SemanticIndexBuilder<'db> {
fn visit_stmt(&mut self, stmt: &'db ast::Stmt) {
let module = self.module;
#[allow(unsafe_code)]
let statement_id = unsafe {
Expand All @@ -238,7 +239,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
let scope = self.current_scope();
let symbol = FileSymbolId::new(
scope,
self.add_or_update_symbol_with_definition(name.clone(), definition),
self.add_or_update_symbol_with_definition(Cow::Borrowed(name), definition),
);

self.with_type_params(
Expand Down Expand Up @@ -276,7 +277,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
let scope = self.current_scope();
let id = FileSymbolId::new(
self.current_scope(),
self.add_or_update_symbol_with_definition(name.clone(), definition),
self.add_or_update_symbol_with_definition(Cow::Borrowed(name), definition),
);
self.with_type_params(
&name,
Expand Down Expand Up @@ -305,9 +306,9 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
ast::Stmt::Import(ast::StmtImport { names, .. }) => {
for (i, alias) in names.iter().enumerate() {
let symbol_name = if let Some(asname) = &alias.asname {
asname.id.clone()
Cow::Borrowed(&asname.id)
} else {
Name::new(alias.name.id.split('.').next().unwrap())
Cow::Owned(Name::new(alias.name.id.split('.').next().unwrap()))
};

let def = Definition::Import(ImportDefinition {
Expand All @@ -333,7 +334,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
import_id: ScopeImportFromId(statement_id),
name: u32::try_from(i).unwrap(),
});
self.add_or_update_symbol_with_definition(symbol_name.clone(), def);
self.add_or_update_symbol_with_definition(Cow::Borrowed(symbol_name), def);
}
}
ast::Stmt::Assign(node) => {
Expand All @@ -352,7 +353,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
}
}

fn visit_expr(&mut self, expr: &'_ ast::Expr) {
fn visit_expr(&mut self, expr: &'db ast::Expr) {
let module = self.module;
#[allow(unsafe_code)]
let expression_id = unsafe {
Expand All @@ -374,10 +375,10 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
};
match self.current_definition {
Some(definition) if flags.contains(SymbolFlags::IS_DEFINED) => {
self.add_or_update_symbol_with_definition(id.clone(), definition);
self.add_or_update_symbol_with_definition(Cow::Borrowed(id), definition);
}
_ => {
self.add_or_update_symbol(id.clone(), flags);
self.add_or_update_symbol(Cow::Borrowed(id), flags);
}
}

Expand Down
33 changes: 17 additions & 16 deletions crates/red_knot_python_semantic/src/semantic_index/symbol.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::hash::{Hash, Hasher};
use std::ops::Range;

Expand All @@ -15,15 +16,15 @@ use ruff_index::{newtype_index, IndexVec};
use ruff_python_ast::name::Name;

#[derive(Eq, PartialEq, Debug)]
pub struct Symbol {
name: Name,
pub struct Symbol<'db> {
name: Cow<'db, Name>,
flags: SymbolFlags,
/// The nodes that define this symbol, in source order.
definitions: SmallVec<[Definition; 4]>,
}

impl Symbol {
fn new(name: Name, definition: Option<Definition>) -> Self {
impl<'db> Symbol<'db> {
fn new(name: Cow<'db, Name>, definition: Option<Definition>) -> Self {
Self {
name,
flags: SymbolFlags::empty(),
Expand Down Expand Up @@ -255,15 +256,15 @@ pub enum ScopeKind {

/// Symbol table for a specific [`Scope`].
#[derive(Debug)]
pub struct SymbolTable {
pub struct SymbolTable<'db> {
/// The symbols in this scope.
symbols: IndexVec<ScopedSymbolId, Symbol>,
symbols: IndexVec<ScopedSymbolId, Symbol<'db>>,

/// The symbols indexed by name.
symbols_by_name: SymbolMap,
}

impl SymbolTable {
impl<'db> SymbolTable<'db> {
fn new() -> Self {
Self {
symbols: IndexVec::new(),
Expand All @@ -279,7 +280,7 @@ impl SymbolTable {
&self.symbols[symbol_id.into()]
}

pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopedSymbolId> {
pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopedSymbolId> + 'db {
self.symbols.indices()
}

Expand Down Expand Up @@ -313,21 +314,21 @@ impl SymbolTable {
}
}

impl PartialEq for SymbolTable {
impl PartialEq for SymbolTable<'_> {
fn eq(&self, other: &Self) -> bool {
// We don't need to compare the symbols_by_name because the name is already captured in `Symbol`.
self.symbols == other.symbols
}
}

impl Eq for SymbolTable {}
impl Eq for SymbolTable<'_> {}

#[derive(Debug)]
pub(super) struct SymbolTableBuilder {
table: SymbolTable,
pub(super) struct SymbolTableBuilder<'db> {
table: SymbolTable<'db>,
}

impl SymbolTableBuilder {
impl<'db> SymbolTableBuilder<'db> {
pub(super) fn new() -> Self {
Self {
table: SymbolTable::new(),
Expand All @@ -336,7 +337,7 @@ impl SymbolTableBuilder {

pub(super) fn add_or_update_symbol(
&mut self,
name: Name,
name: Cow<'db, Name>,
flags: SymbolFlags,
definition: Option<Definition>,
) -> ScopedSymbolId {
Expand All @@ -345,7 +346,7 @@ impl SymbolTableBuilder {
.table
.symbols_by_name
.raw_entry_mut()
.from_hash(hash, |id| self.table.symbols[*id].name() == &name);
.from_hash(hash, |id| self.table.symbols[*id].name() == &*name);

match entry {
RawEntryMut::Occupied(entry) => {
Expand All @@ -371,7 +372,7 @@ impl SymbolTableBuilder {
}
}

pub(super) fn finish(mut self) -> SymbolTable {
pub(super) fn finish(mut self) -> SymbolTable<'db> {
self.table.shrink_to_fit();
self.table
}
Expand Down
4 changes: 2 additions & 2 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ pub(super) struct TypeInferenceBuilder<'a> {
db: &'a dyn Db,

// Cached lookups
index: &'a SemanticIndex,
index: &'a SemanticIndex<'a>,
scope: ScopeId<'a>,
file_scope_id: FileScopeId,
file_id: VfsFile,
symbol_table: Arc<SymbolTable>,
symbol_table: Arc<SymbolTable<'a>>,

/// The type inference results
types: TypeInference<'a>,
Expand Down

0 comments on commit df99bef

Please sign in to comment.