Skip to content

Commit

Permalink
Add quote support
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Aug 20, 2023
1 parent 1b7e4a1 commit a457f8b
Show file tree
Hide file tree
Showing 8 changed files with 555 additions and 99 deletions.
28 changes: 28 additions & 0 deletions crates/ruff/resources/test/fixtures/flake8_type_checking/TCH002.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,31 @@ def f():
from module import Member

x: Member = 1


def f():
from pandas import DataFrame

def baz() -> DataFrame:
...


def f():
from pandas import DataFrame

def baz() -> DataFrame[int]:
...


def f():
from pandas import DataFrame

def baz() -> DataFrame["int"]:
...


def f():
import pandas as pd

def baz() -> pd.DataFrame:
...
96 changes: 63 additions & 33 deletions crates/ruff/src/checkers/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ where
{
if let Some(expr) = &parameter_with_default.parameter.annotation {
if runtime_annotation {
self.visit_runtime_annotation(expr);
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
Expand All @@ -504,7 +504,7 @@ where
if let Some(arg) = &parameters.vararg {
if let Some(expr) = &arg.annotation {
if runtime_annotation {
self.visit_runtime_annotation(expr);
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
Expand All @@ -513,15 +513,15 @@ where
if let Some(arg) = &parameters.kwarg {
if let Some(expr) = &arg.annotation {
if runtime_annotation {
self.visit_runtime_annotation(expr);
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
}
}
for expr in returns {
if runtime_annotation {
self.visit_runtime_annotation(expr);
self.visit_runtime_evaluated_annotation(expr);
} else {
self.visit_annotation(expr);
};
Expand Down Expand Up @@ -652,38 +652,60 @@ where
value,
..
}) => {
// If we're in a class or module scope, then the annotation needs to be
// available at runtime.
// See: https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements
let runtime_annotation = if self.semantic.future_annotations() {
if self.semantic.current_scope().kind.is_class() {
let baseclasses = &self
.settings
.flake8_type_checking
.runtime_evaluated_base_classes;
let decorators = &self
.settings
.flake8_type_checking
.runtime_evaluated_decorators;
flake8_type_checking::helpers::runtime_evaluated(
enum AnnotationKind {
RuntimeRequired,
RuntimeEvaluated,
TypingOnly,
}

fn annotation_kind(
semantic: &SemanticModel,
settings: &Settings,
) -> AnnotationKind {
// If the annotation is in a class, and that class is marked as
// runtime-evaluated, treat the annotation as runtime-required.
// TODO(charlie): We could also include function calls here.
if semantic.current_scope().kind.is_class() {
let baseclasses =
&settings.flake8_type_checking.runtime_evaluated_base_classes;
let decorators =
&settings.flake8_type_checking.runtime_evaluated_decorators;
if flake8_type_checking::helpers::runtime_required(
baseclasses,
decorators,
&self.semantic,
)
} else {
false
semantic,
) {
return AnnotationKind::RuntimeRequired;
}
}
} else {
matches!(
self.semantic.current_scope().kind,

// If `__future__` annotations are enabled, then annotations are never evaluated
// at runtime, so we can treat them as typing-only.
if semantic.future_annotations() {
return AnnotationKind::TypingOnly;
}

// Otherwise, if we're in a class or module scope, then the annotation needs to
// be available at runtime.
// See: https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements
if matches!(
semantic.current_scope().kind,
ScopeKind::Class(_) | ScopeKind::Module
)
};
) {
return AnnotationKind::RuntimeEvaluated;
}

if runtime_annotation {
self.visit_runtime_annotation(annotation);
} else {
self.visit_annotation(annotation);
AnnotationKind::TypingOnly
}

match annotation_kind(&self.semantic, self.settings) {
AnnotationKind::RuntimeRequired => {
self.visit_runtime_required_annotation(annotation);
}
AnnotationKind::RuntimeEvaluated => {
self.visit_runtime_evaluated_annotation(annotation);
}
AnnotationKind::TypingOnly => self.visit_annotation(annotation),
}
if let Some(expr) = value {
if self.semantic.match_typing_expr(annotation, "TypeAlias") {
Expand Down Expand Up @@ -1479,10 +1501,18 @@ impl<'a> Checker<'a> {
self.semantic.flags = snapshot;
}

/// Visit an [`Expr`], and treat it as a runtime-evaluated type annotation.
fn visit_runtime_evaluated_annotation(&mut self, expr: &'a Expr) {
let snapshot = self.semantic.flags;
self.semantic.flags |= SemanticModelFlags::RUNTIME_EVALUATED_ANNOTATION;
self.visit_type_definition(expr);
self.semantic.flags = snapshot;
}

/// Visit an [`Expr`], and treat it as a runtime-required type annotation.
fn visit_runtime_annotation(&mut self, expr: &'a Expr) {
fn visit_runtime_required_annotation(&mut self, expr: &'a Expr) {
let snapshot = self.semantic.flags;
self.semantic.flags |= SemanticModelFlags::RUNTIME_ANNOTATION;
self.semantic.flags |= SemanticModelFlags::RUNTIME_REQUIRED_ANNOTATION;
self.visit_type_definition(expr);
self.semantic.flags = snapshot;
}
Expand Down
21 changes: 15 additions & 6 deletions crates/ruff/src/rules/flake8_type_checking/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,40 @@ pub(crate) fn is_valid_runtime_import(binding: &Binding, semantic: &SemanticMode
binding.context.is_runtime()
&& binding
.references()
.any(|reference_id| semantic.reference(reference_id).context().is_runtime())
.map(|reference_id| semantic.reference(reference_id))
.any(|reference| {
// This is like: typing context _or_ a runtime-required type annotation (since
// we're willing to quote it).
!(reference.in_type_checking_block()
|| reference.in_typing_only_annotation()
|| reference.in_runtime_evaluated_annotation()
|| reference.in_complex_string_type_definition()
|| reference.in_simple_string_type_definition())
})
} else {
false
}
}

pub(crate) fn runtime_evaluated(
pub(crate) fn runtime_required(
base_classes: &[String],
decorators: &[String],
semantic: &SemanticModel,
) -> bool {
if !base_classes.is_empty() {
if runtime_evaluated_base_class(base_classes, semantic) {
if runtime_required_base_class(base_classes, semantic) {
return true;
}
}
if !decorators.is_empty() {
if runtime_evaluated_decorators(decorators, semantic) {
if runtime_required_decorators(decorators, semantic) {
return true;
}
}
false
}

fn runtime_evaluated_base_class(base_classes: &[String], semantic: &SemanticModel) -> bool {
fn runtime_required_base_class(base_classes: &[String], semantic: &SemanticModel) -> bool {
let ScopeKind::Class(class_def) = &semantic.current_scope().kind else {
return false;
};
Expand All @@ -48,7 +57,7 @@ fn runtime_evaluated_base_class(base_classes: &[String], semantic: &SemanticMode
})
}

fn runtime_evaluated_decorators(decorators: &[String], semantic: &SemanticModel) -> bool {
fn runtime_required_decorators(decorators: &[String], semantic: &SemanticModel) -> bool {
let ScopeKind::Class(class_def) = &semantic.current_scope().kind else {
return false;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ use std::borrow::Cow;
use anyhow::Result;
use rustc_hash::FxHashMap;

use ruff_diagnostics::{AutofixKind, Diagnostic, DiagnosticKind, Fix, Violation};
use ruff_diagnostics::{AutofixKind, Diagnostic, DiagnosticKind, Edit, Fix, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::Ranged;
use ruff_python_codegen::Stylist;
use ruff_python_semantic::{AnyImport, Binding, Imported, ResolvedReferenceId, Scope, StatementId};
use ruff_text_size::TextRange;
use ruff_source_file::Locator;
use ruff_text_size::{TextLen, TextRange, TextSize};

use crate::autofix;
use crate::checkers::ast::Checker;
Expand Down Expand Up @@ -254,13 +256,19 @@ pub(crate) fn typing_only_runtime_import(
};

if binding.context.is_runtime()
&& binding.references().all(|reference_id| {
checker
.semantic()
.reference(reference_id)
.context()
.is_typing()
})
&& binding
.references()
.map(|reference_id| checker.semantic().reference(reference_id))
.all(|reference| {
// All references should be in a typing context _or_ a runtime-evaluated
// annotation (as opposed to a runtime-required annotation), which we can
// quote.
reference.in_type_checking_block()
|| reference.in_typing_only_annotation()
|| reference.in_runtime_evaluated_annotation()
|| reference.in_complex_string_type_definition()
|| reference.in_simple_string_type_definition()
})
{
let qualified_name = import.qualified_name();

Expand Down Expand Up @@ -309,6 +317,7 @@ pub(crate) fn typing_only_runtime_import(
let import = ImportBinding {
import,
reference_id,
binding,
range: binding.range(),
parent_range: binding.parent_range(checker.semantic()),
};
Expand Down Expand Up @@ -383,6 +392,8 @@ pub(crate) fn typing_only_runtime_import(
struct ImportBinding<'a> {
/// The qualified name of the import (e.g., `typing.List` for `from typing import List`).
import: AnyImport<'a>,
/// The binding for the imported symbol.
binding: &'a Binding<'a>,
/// The first reference to the imported symbol.
reference_id: ResolvedReferenceId,
/// The trimmed range of the import (e.g., `List` in `from typing import List`).
Expand Down Expand Up @@ -489,8 +500,98 @@ fn fix_imports(
checker.source_type,
)?;

Ok(
Fix::suggested_edits(remove_import_edit, add_import_edit.into_edits())
.isolate(checker.parent_isolation()),
// Step 3) Quote any runtime usages of the referenced symbol.
let quote_reference_edits = imports.iter().flat_map(|ImportBinding { binding, .. }| {
binding.references.iter().filter_map(|reference_id| {
let reference = checker.semantic().reference(*reference_id);
if reference.context().is_runtime() {
Some(quote_annotation(
reference.range(),
checker.locator(),
checker.stylist(),
))
} else {
None
}
})
});

Ok(Fix::suggested_edits(
remove_import_edit,
add_import_edit
.into_edits()
.into_iter()
.chain(quote_reference_edits),
)
.isolate(checker.parent_isolation()))
}

/// Quote a type annotation.
///
/// This requires more than wrapping the reference in quotes. For example:
/// - When quoting `Series` in `Series[pd.Timestamp]`, we want `"Series[pd.Timestamp]"`.
/// - When quoting `kubernetes` in `kubernetes.SecurityContext`, we want `"kubernetes.SecurityContext"`.
/// - When quoting `Series` in `Series["pd.Timestamp"]`, we want `"Series[pd.Timestamp]"`.
fn quote_annotation(range: TextRange, locator: &Locator, stylist: &Stylist) -> Edit {
// Expand the annotation to the end of the reference.
let mut depth = 0u32;
let mut len = TextSize::default();
let mut annotation = String::with_capacity(range.len().into());
for c in locator.after(range.start()).chars() {
match c {
'[' => depth += 1,
']' => {
// Ex) Quoting `int` in `DataFrame[int]`, which should expand until the end of the
// `int` symbol`.
if depth == 0 {
break;
}

depth -= 1;

// Ex) Quoting `DataFrame` in `DataFrame[int]`, which should expand until the end
// of the subscript.
if depth == 0 {
annotation.push(c);
len += c.text_len();
break;
}
}
'.' => {
// Expand attributes.
}
'a'..='z' | 'A'..='Z' | '_' | '0'..='9' => {
// Expand identifiers.
}
'"' | '\'' => {
// Skip quotes.
// TODO(charlie): Retain escaped quotes, and quotes in literals.
len += c.text_len();
continue;
}
'\n' | '\r' if depth > 0 => {
// If we hit a newline, fallback to replacing the range. This can be ugly, but is
// better than not quoting at all.
let annotation = locator.slice(range);
let quote = stylist.quote();
let annotation = format!("{quote}{annotation}{quote}");
return Edit::range_replacement(annotation, range);
}
_ => {
// If we hit a space, or a parenthesis, or any other character (and we're not in
// a subscript), we're done.
if depth == 0 {
break;
}
}
}
annotation.push(c);
len += c.text_len();
}

// Wrap in quotes.
let quote = stylist.quote();
let annotation = format!("{quote}{annotation}{quote}");

Edit::range_replacement(annotation, TextRange::at(range.start(), len))
}
Loading

0 comments on commit a457f8b

Please sign in to comment.