Skip to content

Commit

Permalink
[flake8-type-checking] Skip quoting annotation if it becomes invali…
Browse files Browse the repository at this point in the history
…d syntax (`TCH001`) (#14285)

Fix: #13934 

## Summary

Current implementation has a bug when the current annotation contains a
string with single and double quotes.

TL;DR: I think these cases happen less than other use cases of Literal.
So instead of fixing them we skip the fix in those cases.

One of the problematic cases:

```
from typing import Literal
from third_party import Type

def error(self, type1: Type[Literal["'"]]):
    pass
```

The outcome is:

```
- def error(self, type1: Type[Literal["'"]]):
+ def error(self, type1: "Type[Literal[''']]"):
```

While it should be:

```
"Type[Literal['\'']"
```

The solution in this case is that we check if there’s any quotes same as
the quote style we want to use for this Literal parameter then escape
that same quote used in the string.

Also this case is not uncommon to have:
<https://grep.app/search?current=2&q=Literal["'>

But this can get more complicated for example in case of:

```
- def error(self, type1: Type[Literal["\'"]]):
+ def error(self, type1: "Type[Literal[''']]"):
```

Here we escaped the inner quote but in the generated annotation it gets
removed. Then we flip the quote style of the Literal paramter and the
formatting is wrong.

In this case the solution is more complicated.
1. When generating the string of the source code preserve the backslash.
2. After we have the annotation check if there isn’t any escaped quote
of the same type we want to use for the Literal parameter. In this case
check if we have any `’` without `\` before them. This can get more
complicated since there can be multiple backslashes so checking for only
`\’` won’t be enough.

Another problem is when the string contains `\n`. In case of
`Type[Literal["\n"]]` we generate `'Type[Literal["\n"]]'` and both
pyright and mypy reject this annotation.

https://pyright-play.net/?code=GYJw9gtgBALgngBwJYDsDmUkQWEMoAySMApiAIYA2AUAMaXkDOjUAKoiQNqsC6AXFAB0w6tQAmJYLBKMYAfQCOAVzCk5tMChjlUjOQCNytANaMGjABYAKRiUrAANLA4BGAQHJ2CLkVIVKnABEADoogTw87gCUfNRQ8VAITIyiElKksooqahpaOih6hiZmTNa29k7w3m5sHJy%2BZFRBoeE8MXEJScxAA

## Test Plan

I added test cases for the original code in the reported issue and two
more cases for backslash and new line.

---------

Co-authored-by: Dhruv Manilawala <dhruvmanila@gmail.com>
  • Loading branch information
Glyphack and dhruvmanila authored Nov 15, 2024
1 parent 1f82731 commit 6591775
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,18 @@ def f():
def test_annotated_non_typing_reference(user: Annotated[str, Depends(get_foo)]):
pass


def f():
from typing import Literal
from third_party import Type

def test_string_contains_opposite_quote_do_not_fix(self, type1: Type[Literal["'"]], type2: Type[Literal["\'"]]):
pass


def f():
from typing import Literal
from third_party import Type

def test_quote_contains_backslash(self, type1: Type[Literal["\n"]], type2: Type[Literal["\""]]):
pass
45 changes: 32 additions & 13 deletions crates/ruff_linter/src/rules/flake8_type_checking/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use anyhow::Result;
use ast::visitor::source_order;
use ruff_python_ast::visitor::source_order::SourceOrderVisitor;
use std::cmp::Reverse;

use anyhow::Result;

use ruff_diagnostics::Edit;
use ruff_python_ast::helpers::{map_callable, map_subscript};
use ruff_python_ast::name::QualifiedName;
use ruff_python_ast::visitor::source_order::{SourceOrderVisitor, TraversalSignal};
use ruff_python_ast::{self as ast, Decorator, Expr};
use ruff_python_codegen::{Generator, Stylist};
use ruff_python_semantic::{
Expand Down Expand Up @@ -221,8 +221,8 @@ pub(crate) fn is_singledispatch_implementation(
/// This requires more than just wrapping the reference itself in quotes. For example:
/// - When quoting `Series` in `Series[pd.Timestamp]`, we want `"Series[pd.Timestamp]"`.
/// - When quoting `kubernetes` in `kubernetes.SecurityContext`, we want `"kubernetes.SecurityContext"`.
/// - When quoting `Series` in `Series["pd.Timestamp"]`, we want `"Series[pd.Timestamp]"`. (This is currently unsupported.)
/// - When quoting `Series` in `Series[Literal["pd.Timestamp"]]`, we want `"Series[Literal['pd.Timestamp']]"`. (This is currently unsupported.)
/// - When quoting `Series` in `Series["pd.Timestamp"]`, we want `"Series[pd.Timestamp]"`.
/// - When quoting `Series` in `Series[Literal["pd.Timestamp"]]`, we want `"Series[Literal['pd.Timestamp']]"`.
///
/// In general, when expanding a component of a call chain, we want to quote the entire call chain.
pub(crate) fn quote_annotation(
Expand Down Expand Up @@ -272,7 +272,7 @@ pub(crate) fn quote_annotation(
let quote = stylist.quote();
let mut quote_annotator = QuoteAnnotator::new(semantic, stylist);
quote_annotator.visit_expr(expr);
let annotation = quote_annotator.into_annotation();
let annotation = quote_annotator.into_annotation()?;

Ok(Edit::range_replacement(
format!("{quote}{annotation}{quote}"),
Expand Down Expand Up @@ -313,6 +313,7 @@ pub(crate) struct QuoteAnnotator<'a> {
semantic: &'a SemanticModel<'a>,
state: Vec<QuoteAnnotatorState>,
annotation: String,
cannot_fix: bool,
}

impl<'a> QuoteAnnotator<'a> {
Expand All @@ -322,15 +323,30 @@ impl<'a> QuoteAnnotator<'a> {
semantic,
state: Vec::new(),
annotation: String::new(),
cannot_fix: false,
}
}

fn into_annotation(self) -> String {
self.annotation
fn into_annotation(self) -> Result<String> {
if self.cannot_fix {
Err(anyhow::anyhow!(
"Cannot quote annotation because it already contains opposite quote or escape character"
))
} else {
Ok(self.annotation)
}
}
}

impl<'a> source_order::SourceOrderVisitor<'a> for QuoteAnnotator<'a> {
impl<'a> SourceOrderVisitor<'a> for QuoteAnnotator<'a> {
fn enter_node(&mut self, _node: ast::AnyNodeRef<'a>) -> TraversalSignal {
if self.cannot_fix {
TraversalSignal::Skip
} else {
TraversalSignal::Traverse
}
}

fn visit_expr(&mut self, expr: &'a Expr) {
let generator = Generator::from(self.stylist);

Expand Down Expand Up @@ -388,10 +404,13 @@ impl<'a> source_order::SourceOrderVisitor<'a> for QuoteAnnotator<'a> {
let source = match self.state.last().copied() {
Some(QuoteAnnotatorState::Literal | QuoteAnnotatorState::AnnotatedRest) => {
let mut source = generator.expr(expr);
source = source.replace(
self.stylist.quote().as_char(),
&self.stylist.quote().opposite().as_char().to_string(),
);
let opposite_quote = &self.stylist.quote().opposite().as_char().to_string();
// If the quotes we are going to insert in this source already exists set the auto quote outcome
// to failed. Because this means we are inserting quotes that are in the string and they collect.
if source.contains(opposite_quote) || source.contains('\\') {
self.cannot_fix = true;
}
source = source.replace(self.stylist.quote().as_char(), opposite_quote);
source
}
None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
---
source: crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
snapshot_kind: text
---
quote.py:2:24: TCH002 [*] Move third-party import `pandas.DataFrame` into a type-checking block
|
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
---
source: crates/ruff_linter/src/rules/flake8_type_checking/mod.rs
snapshot_kind: text
---
quote2.py:2:44: TCH002 [*] Move third-party import `django.contrib.auth.models.AbstractBaseUser` into a type-checking block
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,26 @@ quote3.py:40:37: TCH002 [*] Move third-party import `django.contrib.auth.models`
45 |+ def test_attribute_typing_literal(arg: 'models.AbstractBaseUser[Literal["admin"]]'):
43 46 | pass
44 47 |
45 48 |
45 48 |

quote3.py:59:29: TCH002 Move third-party import `third_party.Type` into a type-checking block
|
57 | def f():
58 | from typing import Literal
59 | from third_party import Type
| ^^^^ TCH002
60 |
61 | def test_string_contains_opposite_quote_do_not_fix(self, type1: Type[Literal["'"]], type2: Type[Literal["\'"]]):
|
= help: Move into type-checking block

quote3.py:67:29: TCH002 Move third-party import `third_party.Type` into a type-checking block
|
65 | def f():
66 | from typing import Literal
67 | from third_party import Type
| ^^^^ TCH002
68 |
69 | def test_quote_contains_backslash(self, type1: Type[Literal["\n"]], type2: Type[Literal["\""]]):
|
= help: Move into type-checking block

0 comments on commit 6591775

Please sign in to comment.