diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_type_checking/quote.py b/crates/ruff_linter/resources/test/fixtures/flake8_type_checking/quote.py new file mode 100644 index 00000000000000..1c3ba3c84642cb --- /dev/null +++ b/crates/ruff_linter/resources/test/fixtures/flake8_type_checking/quote.py @@ -0,0 +1,33 @@ +def f(): + from pandas import DataFrame + + def baz() -> DataFrame: + ... + + +def f(): + from pandas import DataFrame + + def baz() -> DataFrame[int]: + ... + + +def f(): + from pandas import DataFrame + + def baz() -> DataFrame["int"]: + ... + + +def f(): + import pandas as pd + + def baz() -> pd.DataFrame: + ... + + +def f(): + import pandas as pd + + def baz() -> pd.DataFrame.Extra: + ... diff --git a/crates/ruff_linter/src/checkers/ast/mod.rs b/crates/ruff_linter/src/checkers/ast/mod.rs index 6c347af7d18444..14556ee514bbdc 100644 --- a/crates/ruff_linter/src/checkers/ast/mod.rs +++ b/crates/ruff_linter/src/checkers/ast/mod.rs @@ -514,8 +514,10 @@ where .chain(¶meters.kwonlyargs) { if let Some(expr) = ¶meter_with_default.parameter.annotation { - if runtime_annotation || singledispatch { - self.visit_runtime_annotation(expr); + if singledispatch { + self.visit_runtime_required_annotation(expr); + } else if runtime_annotation { + self.visit_runtime_evaluated_annotation(expr); } else { self.visit_annotation(expr); }; @@ -528,7 +530,7 @@ where if let Some(arg) = ¶meters.vararg { if let Some(expr) = &arg.annotation { if runtime_annotation { - self.visit_runtime_annotation(expr); + self.visit_runtime_evaluated_annotation(expr); } else { self.visit_annotation(expr); }; @@ -537,7 +539,7 @@ where if let Some(arg) = ¶meters.kwarg { if let Some(expr) = &arg.annotation { if runtime_annotation { - self.visit_runtime_annotation(expr); + self.visit_runtime_evaluated_annotation(expr); } else { self.visit_annotation(expr); }; @@ -545,7 +547,7 @@ where } for expr in returns { if runtime_annotation { - self.visit_runtime_annotation(expr); + self.visit_runtime_evaluated_annotation(expr); } else { self.visit_annotation(expr); }; @@ -676,40 +678,64 @@ where value, .. }) => { - // If we're in a class or module scope, then the annotation needs to be - // available at runtime. - // See: https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements - let runtime_annotation = if self.semantic.future_annotations() { - self.semantic + enum AnnotationKind { + RuntimeRequired, + RuntimeEvaluated, + TypingOnly, + } + + fn annotation_kind( + semantic: &SemanticModel, + settings: &LinterSettings, + ) -> AnnotationKind { + // If the annotation is in a class, and that class is marked as + // runtime-evaluated, treat the annotation as runtime-required. + // TODO(charlie): We could also include function calls here. + if semantic .current_scope() .kind .as_class() .is_some_and(|class_def| { - flake8_type_checking::helpers::runtime_evaluated_class( + flake8_type_checking::helpers::runtime_required_class( class_def, - &self - .settings - .flake8_type_checking - .runtime_evaluated_base_classes, - &self - .settings - .flake8_type_checking - .runtime_evaluated_decorators, - &self.semantic, + &settings.flake8_type_checking.runtime_evaluated_base_classes, + &settings.flake8_type_checking.runtime_evaluated_decorators, + semantic, ) }) - } else { - matches!( - self.semantic.current_scope().kind, + { + return AnnotationKind::RuntimeRequired; + } + + // If `__future__` annotations are enabled, then annotations are never evaluated + // at runtime, so we can treat them as typing-only. + if semantic.future_annotations() { + return AnnotationKind::TypingOnly; + } + + // Otherwise, if we're in a class or module scope, then the annotation needs to + // be available at runtime. + // See: https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements + if matches!( + semantic.current_scope().kind, ScopeKind::Class(_) | ScopeKind::Module - ) - }; + ) { + return AnnotationKind::RuntimeEvaluated; + } - if runtime_annotation { - self.visit_runtime_annotation(annotation); - } else { - self.visit_annotation(annotation); + AnnotationKind::TypingOnly + } + + match annotation_kind(&self.semantic, self.settings) { + AnnotationKind::RuntimeRequired => { + self.visit_runtime_required_annotation(annotation); + } + AnnotationKind::RuntimeEvaluated => { + self.visit_runtime_evaluated_annotation(annotation); + } + AnnotationKind::TypingOnly => self.visit_annotation(annotation), } + if let Some(expr) = value { if self.semantic.match_typing_expr(annotation, "TypeAlias") { self.visit_type_definition(expr); @@ -1526,10 +1552,18 @@ impl<'a> Checker<'a> { self.semantic.flags = snapshot; } + /// Visit an [`Expr`], and treat it as a runtime-evaluated type annotation. + fn visit_runtime_evaluated_annotation(&mut self, expr: &'a Expr) { + let snapshot = self.semantic.flags; + self.semantic.flags |= SemanticModelFlags::RUNTIME_EVALUATED_ANNOTATION; + self.visit_type_definition(expr); + self.semantic.flags = snapshot; + } + /// Visit an [`Expr`], and treat it as a runtime-required type annotation. - fn visit_runtime_annotation(&mut self, expr: &'a Expr) { + fn visit_runtime_required_annotation(&mut self, expr: &'a Expr) { let snapshot = self.semantic.flags; - self.semantic.flags |= SemanticModelFlags::RUNTIME_ANNOTATION; + self.semantic.flags |= SemanticModelFlags::RUNTIME_REQUIRED_ANNOTATION; self.visit_type_definition(expr); self.semantic.flags = snapshot; } diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs b/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs index 0a51e151f4703c..096ed058832e0e 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs +++ b/crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs @@ -12,28 +12,37 @@ pub(crate) fn is_valid_runtime_import(binding: &Binding, semantic: &SemanticMode binding.context.is_runtime() && binding .references() - .any(|reference_id| semantic.reference(reference_id).context().is_runtime()) + .map(|reference_id| semantic.reference(reference_id)) + .any(|reference| { + // This is like: typing context _or_ a runtime-required type annotation (since + // we're willing to quote it). + !(reference.in_type_checking_block() + || reference.in_typing_only_annotation() + || reference.in_runtime_evaluated_annotation() + || reference.in_complex_string_type_definition() + || reference.in_simple_string_type_definition()) + }) } else { false } } -pub(crate) fn runtime_evaluated_class( +pub(crate) fn runtime_required_class( class_def: &ast::StmtClassDef, base_classes: &[String], decorators: &[String], semantic: &SemanticModel, ) -> bool { - if runtime_evaluated_base_class(class_def, base_classes, semantic) { + if runtime_required_base_class(class_def, base_classes, semantic) { return true; } - if runtime_evaluated_decorators(class_def, decorators, semantic) { + if runtime_required_decorators(class_def, decorators, semantic) { return true; } false } -fn runtime_evaluated_base_class( +fn runtime_required_base_class( class_def: &ast::StmtClassDef, base_classes: &[String], semantic: &SemanticModel, @@ -45,7 +54,7 @@ fn runtime_evaluated_base_class( seen: &mut FxHashSet, ) -> bool { class_def.bases().iter().any(|expr| { - // If the base class is itself runtime-evaluated, then this is too. + // If the base class is itself runtime-required, then this is too. // Ex) `class Foo(BaseModel): ...` if semantic .resolve_call_path(map_subscript(expr)) @@ -58,7 +67,7 @@ fn runtime_evaluated_base_class( return true; } - // If the base class extends a runtime-evaluated class, then this does too. + // If the base class extends a runtime-required class, then this does too. // Ex) `class Bar(BaseModel): ...; class Foo(Bar): ...` if let Some(id) = semantic.lookup_attribute(map_subscript(expr)) { if seen.insert(id) { @@ -86,7 +95,7 @@ fn runtime_evaluated_base_class( inner(class_def, base_classes, semantic, &mut FxHashSet::default()) } -fn runtime_evaluated_decorators( +fn runtime_required_decorators( class_def: &ast::StmtClassDef, decorators: &[String], semantic: &SemanticModel, diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/mod.rs b/crates/ruff_linter/src/rules/flake8_type_checking/mod.rs index 82b24755f4277a..dd494af115e7fd 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/mod.rs +++ b/crates/ruff_linter/src/rules/flake8_type_checking/mod.rs @@ -37,6 +37,7 @@ mod tests { #[test_case(Rule::TypingOnlyStandardLibraryImport, Path::new("TCH003.py"))] #[test_case(Rule::TypingOnlyStandardLibraryImport, Path::new("snapshot.py"))] #[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("TCH002.py"))] + #[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("quote.py"))] #[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("singledispatch.py"))] #[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("strict.py"))] #[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("typing_modules_1.py"))] diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs b/crates/ruff_linter/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs index 1ce6336d3b158e..7f2b21990f1f40 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs +++ b/crates/ruff_linter/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs @@ -3,9 +3,14 @@ use std::borrow::Cow; use anyhow::Result; use rustc_hash::FxHashMap; -use ruff_diagnostics::{Diagnostic, DiagnosticKind, Fix, FixAvailability, Violation}; +use ruff_diagnostics::{Diagnostic, DiagnosticKind, Edit, Fix, FixAvailability, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_semantic::{AnyImport, Binding, Imported, NodeId, ResolvedReferenceId, Scope}; +use ruff_python_ast::Expr; +use ruff_python_codegen::Stylist; +use ruff_python_semantic::{ + AnyImport, Binding, Imported, NodeId, ResolvedReferenceId, Scope, SemanticModel, +}; +use ruff_source_file::Locator; use ruff_text_size::{Ranged, TextRange}; use crate::checkers::ast::Checker; @@ -253,13 +258,19 @@ pub(crate) fn typing_only_runtime_import( }; if binding.context.is_runtime() - && binding.references().all(|reference_id| { - checker - .semantic() - .reference(reference_id) - .context() - .is_typing() - }) + && binding + .references() + .map(|reference_id| checker.semantic().reference(reference_id)) + .all(|reference| { + // All references should be in a typing context _or_ a runtime-evaluated + // annotation (as opposed to a runtime-required annotation), which we can + // quote. + reference.in_type_checking_block() + || reference.in_typing_only_annotation() + || reference.in_runtime_evaluated_annotation() + || reference.in_complex_string_type_definition() + || reference.in_simple_string_type_definition() + }) { let qualified_name = import.qualified_name(); @@ -310,6 +321,7 @@ pub(crate) fn typing_only_runtime_import( let import = ImportBinding { import, reference_id, + binding, range: binding.range(), parent_range: binding.parent_range(checker.semantic()), }; @@ -380,6 +392,8 @@ pub(crate) fn typing_only_runtime_import( struct ImportBinding<'a> { /// The qualified name of the import (e.g., `typing.List` for `from typing import List`). import: AnyImport<'a>, + /// The binding for the imported symbol. + binding: &'a Binding<'a>, /// The first reference to the imported symbol. reference_id: ResolvedReferenceId, /// The trimmed range of the import (e.g., `List` in `from typing import List`). @@ -482,9 +496,87 @@ fn fix_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) -> checker.source_type, )?; - Ok( - Fix::unsafe_edits(remove_import_edit, add_import_edit.into_edits()).isolate( - Checker::isolation(checker.semantic().parent_statement_id(node_id)), - ), + // Step 3) Quote any runtime usages of the referenced symbol. + let quote_reference_edits = imports + .iter() + .flat_map(|ImportBinding { binding, .. }| { + binding.references.iter().filter_map(|reference_id| { + let reference = checker.semantic().reference(*reference_id); + if reference.context().is_runtime() { + Some(quote_annotation( + reference.node_id()?, + checker.semantic(), + checker.locator(), + checker.stylist(), + )) + } else { + None + } + }) + }) + .collect::>>()?; + + Ok(Fix::unsafe_edits( + remove_import_edit, + add_import_edit + .into_edits() + .into_iter() + .chain(quote_reference_edits), ) + .isolate(Checker::isolation( + checker.semantic().parent_statement_id(node_id), + ))) +} + +/// Wrap a type annotation in quotes. +/// +/// This requires more than just wrapping the reference itself in quotes. For example: +/// - When quoting `Series` in `Series[pd.Timestamp]`, we want `"Series[pd.Timestamp]"`. +/// - When quoting `kubernetes` in `kubernetes.SecurityContext`, we want `"kubernetes.SecurityContext"`. +/// - When quoting `Series` in `Series["pd.Timestamp"]`, we want `"Series[pd.Timestamp]"`. +/// - When quoting `Series` in `Series[Literal["pd.Timestamp"]]`, we want `"Series[Literal['pd.Timestamp']]"`. +fn quote_annotation( + node_id: NodeId, + semantic: &SemanticModel, + locator: &Locator, + stylist: &Stylist, +) -> Result { + let expr = semantic.expression(node_id); + if let Some(parent_id) = semantic.parent_expression_id(node_id) { + match semantic.expression(parent_id) { + Expr::Subscript(parent) => { + if expr == parent.value.as_ref() { + // If we're quoting the value of a subscript, we need to quote the entire expression. + // For example, when quoting `DataFrame` in `DataFrame[int]`, we should generate + // `"DataFrame[int]"`. + // STOPSHIP(charlie): If the value contains a string literal, we need to remove the + // quotes from the literal. + return quote_annotation(parent_id, semantic, locator, stylist); + } + } + Expr::Attribute(parent) => { + if expr == parent.value.as_ref() { + // If we're quoting the value of a subscript, we need to quote the entire expression. + // For example, when quoting `DataFrame` in `DataFrame[int]`, we should generate + // `"DataFrame[int]"`. + // STOPSHIP(charlie): If the value contains a string literal, we need to remove the + // quotes from the literal. + return quote_annotation(parent_id, semantic, locator, stylist); + } + } + _ => {} + } + } + + let annotation = locator.slice(expr); + + // If the annotation already contains a quote, avoid attempting to re-quote it. + if annotation.contains('\'') || annotation.contains('"') { + return Err(anyhow::anyhow!("Annotation already contains a quote")); + } + + // If we're quoting a name, we need to quote the entire expression. + let quote = stylist.quote(); + let annotation = format!("{quote}{annotation}{quote}"); + Ok(Edit::range_replacement(annotation, expr.range())) } diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/settings.rs b/crates/ruff_linter/src/rules/flake8_type_checking/settings.rs index 425f02fe550a41..811163860da554 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/settings.rs +++ b/crates/ruff_linter/src/rules/flake8_type_checking/settings.rs @@ -14,7 +14,7 @@ impl Default for Settings { fn default() -> Self { Self { strict: false, - exempt_modules: vec!["typing".to_string()], + exempt_modules: vec!["typing".to_string(), "typing_extensions".to_string()], runtime_evaluated_base_classes: vec![], runtime_evaluated_decorators: vec![], } diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__typing-only-third-party-import_quote.py.snap b/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__typing-only-third-party-import_quote.py.snap new file mode 100644 index 00000000000000..403277b9af8860 --- /dev/null +++ b/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__typing-only-third-party-import_quote.py.snap @@ -0,0 +1,126 @@ +--- +source: crates/ruff_linter/src/rules/flake8_type_checking/mod.rs +--- +quote.py:2:24: TCH002 [*] Move third-party import `pandas.DataFrame` into a type-checking block + | +1 | def f(): +2 | from pandas import DataFrame + | ^^^^^^^^^ TCH002 +3 | +4 | def baz() -> DataFrame: + | + = help: Move into type-checking block + +ℹ Unsafe fix +1 |-def f(): + 1 |+from typing import TYPE_CHECKING + 2 |+ + 3 |+if TYPE_CHECKING: +2 4 | from pandas import DataFrame + 5 |+def f(): +3 6 | +4 |- def baz() -> DataFrame: + 7 |+ def baz() -> "DataFrame": +5 8 | ... +6 9 | +7 10 | + +quote.py:9:24: TCH002 [*] Move third-party import `pandas.DataFrame` into a type-checking block + | + 8 | def f(): + 9 | from pandas import DataFrame + | ^^^^^^^^^ TCH002 +10 | +11 | def baz() -> DataFrame[int]: + | + = help: Move into type-checking block + +ℹ Unsafe fix + 1 |+from typing import TYPE_CHECKING + 2 |+ + 3 |+if TYPE_CHECKING: + 4 |+ from pandas import DataFrame +1 5 | def f(): +2 6 | from pandas import DataFrame +3 7 | +-------------------------------------------------------------------------------- +6 10 | +7 11 | +8 12 | def f(): +9 |- from pandas import DataFrame +10 13 | +11 |- def baz() -> DataFrame[int]: + 14 |+ def baz() -> "DataFrame[int]": +12 15 | ... +13 16 | +14 17 | + +quote.py:16:24: TCH002 Move third-party import `pandas.DataFrame` into a type-checking block + | +15 | def f(): +16 | from pandas import DataFrame + | ^^^^^^^^^ TCH002 +17 | +18 | def baz() -> DataFrame["int"]: + | + = help: Move into type-checking block + +quote.py:23:22: TCH002 [*] Move third-party import `pandas` into a type-checking block + | +22 | def f(): +23 | import pandas as pd + | ^^ TCH002 +24 | +25 | def baz() -> pd.DataFrame: + | + = help: Move into type-checking block + +ℹ Unsafe fix + 1 |+from typing import TYPE_CHECKING + 2 |+ + 3 |+if TYPE_CHECKING: + 4 |+ import pandas as pd +1 5 | def f(): +2 6 | from pandas import DataFrame +3 7 | +-------------------------------------------------------------------------------- +20 24 | +21 25 | +22 26 | def f(): +23 |- import pandas as pd +24 27 | +25 |- def baz() -> pd.DataFrame: + 28 |+ def baz() -> "pd.DataFrame": +26 29 | ... +27 30 | +28 31 | + +quote.py:30:22: TCH002 [*] Move third-party import `pandas` into a type-checking block + | +29 | def f(): +30 | import pandas as pd + | ^^ TCH002 +31 | +32 | def baz() -> pd.DataFrame.Extra: + | + = help: Move into type-checking block + +ℹ Unsafe fix + 1 |+from typing import TYPE_CHECKING + 2 |+ + 3 |+if TYPE_CHECKING: + 4 |+ import pandas as pd +1 5 | def f(): +2 6 | from pandas import DataFrame +3 7 | +-------------------------------------------------------------------------------- +27 31 | +28 32 | +29 33 | def f(): +30 |- import pandas as pd +31 34 | +32 |- def baz() -> pd.DataFrame.Extra: + 35 |+ def baz() -> "pd.DataFrame.Extra": +33 36 | ... + + diff --git a/crates/ruff_python_semantic/src/model.rs b/crates/ruff_python_semantic/src/model.rs index 9cb3cebaa07ff2..b1a63fd1f9d8c5 100644 --- a/crates/ruff_python_semantic/src/model.rs +++ b/crates/ruff_python_semantic/src/model.rs @@ -291,9 +291,12 @@ impl<'a> SemanticModel<'a> { if let Some(binding_id) = self.scopes.global().get(name.id.as_str()) { if !self.bindings[binding_id].is_unbound() { // Mark the binding as used. - let reference_id = - self.resolved_references - .push(ScopeId::global(), name.range, self.flags); + let reference_id = self.resolved_references.push( + ScopeId::global(), + self.node_id, + name.range, + self.flags, + ); self.bindings[binding_id].references.push(reference_id); // Mark any submodule aliases as used. @@ -302,6 +305,7 @@ impl<'a> SemanticModel<'a> { { let reference_id = self.resolved_references.push( ScopeId::global(), + self.node_id, name.range, self.flags, ); @@ -356,18 +360,24 @@ impl<'a> SemanticModel<'a> { if let Some(binding_id) = scope.get(name.id.as_str()) { // Mark the binding as used. - let reference_id = - self.resolved_references - .push(self.scope_id, name.range, self.flags); + let reference_id = self.resolved_references.push( + self.scope_id, + self.node_id, + name.range, + self.flags, + ); self.bindings[binding_id].references.push(reference_id); // Mark any submodule aliases as used. if let Some(binding_id) = self.resolve_submodule(name.id.as_str(), scope_id, binding_id) { - let reference_id = - self.resolved_references - .push(self.scope_id, name.range, self.flags); + let reference_id = self.resolved_references.push( + self.scope_id, + self.node_id, + name.range, + self.flags, + ); self.bindings[binding_id].references.push(reference_id); } @@ -431,9 +441,12 @@ impl<'a> SemanticModel<'a> { // The `x` in `print(x)` should resolve to the `x` in `x = 1`. BindingKind::UnboundException(Some(binding_id)) => { // Mark the binding as used. - let reference_id = - self.resolved_references - .push(self.scope_id, name.range, self.flags); + let reference_id = self.resolved_references.push( + self.scope_id, + self.node_id, + name.range, + self.flags, + ); self.bindings[binding_id].references.push(reference_id); // Mark any submodule aliases as used. @@ -442,6 +455,7 @@ impl<'a> SemanticModel<'a> { { let reference_id = self.resolved_references.push( self.scope_id, + self.node_id, name.range, self.flags, ); @@ -979,6 +993,32 @@ impl<'a> SemanticModel<'a> { &self.nodes[node_id] } + /// Return the [`Expr`] corresponding to the given [`NodeId`]. + #[inline] + pub fn expression(&self, node_id: NodeId) -> &'a Expr { + self.nodes + .ancestor_ids(node_id) + .find_map(|id| self.nodes[id].as_expression()) + .expect("No expression found") + } + + /// Given a [`Expr`], return its parent, if any. + #[inline] + pub fn parent_expression(&self, node_id: NodeId) -> Option<&'a Expr> { + self.nodes + .ancestor_ids(node_id) + .filter_map(|id| self.nodes[id].as_expression()) + .nth(1) + } + + /// Given a [`NodeId`], return the [`NodeId`] of the parent expression, if any. + pub fn parent_expression_id(&self, node_id: NodeId) -> Option { + self.nodes + .ancestor_ids(node_id) + .filter(|id| self.nodes[*id].is_expression()) + .nth(1) + } + /// Return the [`Stmt`] corresponding to the given [`NodeId`]. #[inline] pub fn statement(&self, node_id: NodeId) -> &'a Stmt { @@ -1169,17 +1209,17 @@ impl<'a> SemanticModel<'a> { /// Add a reference to the given [`BindingId`] in the local scope. pub fn add_local_reference(&mut self, binding_id: BindingId, range: TextRange) { - let reference_id = self - .resolved_references - .push(self.scope_id, range, self.flags); + let reference_id = + self.resolved_references + .push(self.scope_id, self.node_id, range, self.flags); self.bindings[binding_id].references.push(reference_id); } /// Add a reference to the given [`BindingId`] in the global scope. pub fn add_global_reference(&mut self, binding_id: BindingId, range: TextRange) { - let reference_id = self - .resolved_references - .push(ScopeId::global(), range, self.flags); + let reference_id = + self.resolved_references + .push(ScopeId::global(), self.node_id, range, self.flags); self.bindings[binding_id].references.push(reference_id); } @@ -1282,10 +1322,16 @@ impl<'a> SemanticModel<'a> { .intersects(SemanticModelFlags::TYPING_ONLY_ANNOTATION) } - /// Return `true` if the model is in a runtime-required type annotation. - pub const fn in_runtime_annotation(&self) -> bool { + /// Return `true` if the context is in a runtime-evaluated type annotation. + pub const fn in_runtime_evaluated_annotation(&self) -> bool { self.flags - .intersects(SemanticModelFlags::RUNTIME_ANNOTATION) + .intersects(SemanticModelFlags::RUNTIME_EVALUATED_ANNOTATION) + } + + /// Return `true` if the context is in a runtime-required type annotation. + pub const fn in_runtime_required_annotation(&self) -> bool { + self.flags + .intersects(SemanticModelFlags::RUNTIME_REQUIRED_ANNOTATION) } /// Return `true` if the model is in a type definition. @@ -1457,8 +1503,9 @@ impl ShadowedBinding { bitflags! { /// Flags indicating the current model state. #[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] - pub struct SemanticModelFlags: u16 { - /// The model is in a typing-time-only type annotation. + pub struct SemanticModelFlags: u32 { + /// The model is in a type annotation that will only be evaluated when running a type + /// checker. /// /// For example, the model could be visiting `int` in: /// ```python @@ -1473,7 +1520,7 @@ bitflags! { /// are any annotated assignments in module or class scopes. const TYPING_ONLY_ANNOTATION = 1 << 0; - /// The model is in a runtime type annotation. + /// The model is in a type annotation that will be evaluated at runtime. /// /// For example, the model could be visiting `int` in: /// ```python @@ -1487,7 +1534,27 @@ bitflags! { /// If `from __future__ import annotations` is used, all annotations are evaluated at /// typing time. Otherwise, all function argument annotations are evaluated at runtime, as /// are any annotated assignments in module or class scopes. - const RUNTIME_ANNOTATION = 1 << 1; + const RUNTIME_EVALUATED_ANNOTATION = 1 << 1; + + /// The model is in a type annotation that is _required_ to be available at runtime. + /// + /// For example, the context could be visiting `int` in: + /// ```python + /// from pydantic import BaseModel + /// + /// class Foo(BaseModel): + /// x: int + /// ``` + /// + /// In this case, Pydantic requires that the type annotation be available at runtime + /// in order to perform runtime type-checking. + /// + /// Unlike [`RUNTIME_EVALUATED_ANNOTATION`], annotations that are marked as + /// [`RUNTIME_REQUIRED_ANNOTATION`] cannot be deferred to typing time via conversion to a + /// forward reference (e.g., by wrapping the type in quotes), as the annotations are not + /// only required by the Python interpreter, but by runtime type checkers too. + const RUNTIME_REQUIRED_ANNOTATION = 1 << 2; + /// The model is in a type definition. /// @@ -1501,7 +1568,7 @@ bitflags! { /// All type annotations are also type definitions, but the converse is not true. /// In our example, `int` is a type definition but not a type annotation, as it /// doesn't appear in a type annotation context, but rather in a type definition. - const TYPE_DEFINITION = 1 << 2; + const TYPE_DEFINITION = 1 << 3; /// The model is in a (deferred) "simple" string type definition. /// @@ -1512,7 +1579,7 @@ bitflags! { /// /// "Simple" string type definitions are those that consist of a single string literal, /// as opposed to an implicitly concatenated string literal. - const SIMPLE_STRING_TYPE_DEFINITION = 1 << 3; + const SIMPLE_STRING_TYPE_DEFINITION = 1 << 4; /// The model is in a (deferred) "complex" string type definition. /// @@ -1523,7 +1590,7 @@ bitflags! { /// /// "Complex" string type definitions are those that consist of a implicitly concatenated /// string literals. These are uncommon but valid. - const COMPLEX_STRING_TYPE_DEFINITION = 1 << 4; + const COMPLEX_STRING_TYPE_DEFINITION = 1 << 5; /// The model is in a (deferred) `__future__` type definition. /// @@ -1536,7 +1603,7 @@ bitflags! { /// /// `__future__`-style type annotations are only enabled if the `annotations` feature /// is enabled via `from __future__ import annotations`. - const FUTURE_TYPE_DEFINITION = 1 << 5; + const FUTURE_TYPE_DEFINITION = 1 << 6; /// The model is in an exception handler. /// @@ -1547,7 +1614,7 @@ bitflags! { /// except Exception: /// x: int = 1 /// ``` - const EXCEPTION_HANDLER = 1 << 6; + const EXCEPTION_HANDLER = 1 << 7; /// The model is in an f-string. /// @@ -1555,7 +1622,7 @@ bitflags! { /// ```python /// f'{x}' /// ``` - const F_STRING = 1 << 7; + const F_STRING = 1 << 8; /// The model is in a boolean test. /// @@ -1567,7 +1634,7 @@ bitflags! { /// /// The implication is that the actual value returned by the current expression is /// not used, only its truthiness. - const BOOLEAN_TEST = 1 << 8; + const BOOLEAN_TEST = 1 << 9; /// The model is in a `typing::Literal` annotation. /// @@ -1576,7 +1643,7 @@ bitflags! { /// def f(x: Literal["A", "B", "C"]): /// ... /// ``` - const TYPING_LITERAL = 1 << 9; + const TYPING_LITERAL = 1 << 10; /// The model is in a subscript expression. /// @@ -1584,7 +1651,7 @@ bitflags! { /// ```python /// x["a"]["b"] /// ``` - const SUBSCRIPT = 1 << 10; + const SUBSCRIPT = 1 << 11; /// The model is in a type-checking block. /// @@ -1596,7 +1663,7 @@ bitflags! { /// if TYPE_CHECKING: /// x: int = 1 /// ``` - const TYPE_CHECKING_BLOCK = 1 << 11; + const TYPE_CHECKING_BLOCK = 1 << 12; /// The model has traversed past the "top-of-file" import boundary. /// @@ -1609,7 +1676,7 @@ bitflags! { /// /// x: int = 1 /// ``` - const IMPORT_BOUNDARY = 1 << 12; + const IMPORT_BOUNDARY = 1 << 13; /// The model has traversed past the `__future__` import boundary. /// @@ -1624,7 +1691,7 @@ bitflags! { /// /// Python considers it a syntax error to import from `__future__` after /// any other non-`__future__`-importing statements. - const FUTURES_BOUNDARY = 1 << 13; + const FUTURES_BOUNDARY = 1 << 14; /// `__future__`-style type annotations are enabled in this model. /// @@ -1636,7 +1703,7 @@ bitflags! { /// def f(x: int) -> int: /// ... /// ``` - const FUTURE_ANNOTATIONS = 1 << 14; + const FUTURE_ANNOTATIONS = 1 << 15; /// The model is in a type parameter definition. /// @@ -1646,10 +1713,11 @@ bitflags! { /// /// Record = TypeVar("Record") /// - const TYPE_PARAM_DEFINITION = 1 << 15; + const TYPE_PARAM_DEFINITION = 1 << 16; /// The context is in any type annotation. - const ANNOTATION = Self::TYPING_ONLY_ANNOTATION.bits() | Self::RUNTIME_ANNOTATION.bits(); + const ANNOTATION = Self::TYPING_ONLY_ANNOTATION.bits() | Self::RUNTIME_EVALUATED_ANNOTATION.bits() | Self::RUNTIME_REQUIRED_ANNOTATION.bits(); + /// The context is in any string type definition. const STRING_TYPE_DEFINITION = Self::SIMPLE_STRING_TYPE_DEFINITION.bits() diff --git a/crates/ruff_python_semantic/src/reference.rs b/crates/ruff_python_semantic/src/reference.rs index a963895b21e697..7e24ea2aad55ad 100644 --- a/crates/ruff_python_semantic/src/reference.rs +++ b/crates/ruff_python_semantic/src/reference.rs @@ -8,11 +8,14 @@ use ruff_text_size::{Ranged, TextRange}; use crate::context::ExecutionContext; use crate::scope::ScopeId; -use crate::{Exceptions, SemanticModelFlags}; +use crate::{Exceptions, NodeId, SemanticModelFlags}; /// A resolved read reference to a name in a program. #[derive(Debug, Clone)] pub struct ResolvedReference { + /// The expression that the reference occurs in. `None` if the reference is a global + /// reference or a reference via an augmented assignment. + node_id: Option, /// The scope in which the reference is defined. scope_id: ScopeId, /// The range of the reference in the source code. @@ -22,6 +25,11 @@ pub struct ResolvedReference { } impl ResolvedReference { + /// The expression that the reference occurs in. + pub const fn node_id(&self) -> Option { + self.node_id + } + /// The scope in which the reference is defined. pub const fn scope_id(&self) -> ScopeId { self.scope_id @@ -35,6 +43,120 @@ impl ResolvedReference { ExecutionContext::Runtime } } + + /// Return `true` if the context is in a type annotation. + pub const fn in_annotation(&self) -> bool { + self.flags.intersects(SemanticModelFlags::ANNOTATION) + } + + /// Return `true` if the context is in a typing-only type annotation. + pub const fn in_typing_only_annotation(&self) -> bool { + self.flags + .intersects(SemanticModelFlags::TYPING_ONLY_ANNOTATION) + } + + /// Return `true` if the context is in a runtime-required type annotation. + pub const fn in_runtime_evaluated_annotation(&self) -> bool { + self.flags + .intersects(SemanticModelFlags::RUNTIME_EVALUATED_ANNOTATION) + } + + /// Return `true` if the context is in a type definition. + pub const fn in_type_definition(&self) -> bool { + self.flags.intersects(SemanticModelFlags::TYPE_DEFINITION) + } + + /// Return `true` if the context is in a "simple" string type definition. + pub const fn in_simple_string_type_definition(&self) -> bool { + self.flags + .intersects(SemanticModelFlags::SIMPLE_STRING_TYPE_DEFINITION) + } + + /// Return `true` if the context is in a "complex" string type definition. + pub const fn in_complex_string_type_definition(&self) -> bool { + self.flags + .intersects(SemanticModelFlags::COMPLEX_STRING_TYPE_DEFINITION) + } + + /// Return `true` if the context is in a `__future__` type definition. + pub const fn in_future_type_definition(&self) -> bool { + self.flags + .intersects(SemanticModelFlags::FUTURE_TYPE_DEFINITION) + } + + /// Return `true` if the context is in any kind of deferred type definition. + pub const fn in_deferred_type_definition(&self) -> bool { + self.flags + .intersects(SemanticModelFlags::DEFERRED_TYPE_DEFINITION) + } + + /// Return `true` if the context is in a forward type reference. + /// + /// Includes deferred string types, and future types in annotations. + /// + /// ## Examples + /// ```python + /// from __future__ import annotations + /// + /// from threading import Thread + /// + /// + /// x: Thread # Forward reference + /// cast("Thread", x) # Forward reference + /// cast(Thread, x) # Non-forward reference + /// ``` + pub const fn in_forward_reference(&self) -> bool { + self.in_simple_string_type_definition() + || self.in_complex_string_type_definition() + || (self.in_future_type_definition() && self.in_typing_only_annotation()) + } + + /// Return `true` if the context is in an exception handler. + pub const fn in_exception_handler(&self) -> bool { + self.flags.intersects(SemanticModelFlags::EXCEPTION_HANDLER) + } + + /// Return `true` if the context is in an f-string. + pub const fn in_f_string(&self) -> bool { + self.flags.intersects(SemanticModelFlags::F_STRING) + } + + /// Return `true` if the context is in boolean test. + pub const fn in_boolean_test(&self) -> bool { + self.flags.intersects(SemanticModelFlags::BOOLEAN_TEST) + } + + /// Return `true` if the context is in a `typing::Literal` annotation. + pub const fn in_typing_literal(&self) -> bool { + self.flags.intersects(SemanticModelFlags::TYPING_LITERAL) + } + + /// Return `true` if the context is in a subscript expression. + pub const fn in_subscript(&self) -> bool { + self.flags.intersects(SemanticModelFlags::SUBSCRIPT) + } + + /// Return `true` if the context is in a type-checking block. + pub const fn in_type_checking_block(&self) -> bool { + self.flags + .intersects(SemanticModelFlags::TYPE_CHECKING_BLOCK) + } + + /// Return `true` if the context has traversed past the "top-of-file" import boundary. + pub const fn seen_import_boundary(&self) -> bool { + self.flags.intersects(SemanticModelFlags::IMPORT_BOUNDARY) + } + + /// Return `true` if the context has traverse past the `__future__` import boundary. + pub const fn seen_futures_boundary(&self) -> bool { + self.flags.intersects(SemanticModelFlags::FUTURES_BOUNDARY) + } + + /// Return `true` if `__future__`-style type annotations are enabled. + pub const fn future_annotations(&self) -> bool { + self.flags + .intersects(SemanticModelFlags::FUTURE_ANNOTATIONS) + } } impl Ranged for ResolvedReference { @@ -57,10 +179,12 @@ impl ResolvedReferences { pub(crate) fn push( &mut self, scope_id: ScopeId, + node_id: Option, range: TextRange, flags: SemanticModelFlags, ) -> ResolvedReferenceId { self.0.push(ResolvedReference { + node_id, scope_id, range, flags,