Skip to content

Commit

Permalink
Insert newline after nested function or class statements (#7946)
Browse files Browse the repository at this point in the history
**Summary** Insert a newline after nested function and class
definitions, unless there is a trailing own line comment.

We need to e.g. format
```python
if platform.system() == "Linux":
    if sys.version > (3, 10):
        def f():
            print("old")
    else:
        def f():
            print("new")
    f()
```
as
```python
if platform.system() == "Linux":
    if sys.version > (3, 10):

        def f():
            print("old")

    else:

        def f():
            print("new")

    f()
```
even though `f()` is directly preceded by an if statement, not a
function or class definition. See the comments and fixtures for trailing
own line comment handling.

**Test Plan** I checked that the new content of `newlines.py` matches
black's formatting.

---------

Co-authored-by: Charlie Marsh <charlie.r.marsh@gmail.com>
  • Loading branch information
konstin and charliermarsh authored Oct 18, 2023
1 parent dda4ced commit 0c3123e
Show file tree
Hide file tree
Showing 8 changed files with 430 additions and 77 deletions.
66 changes: 63 additions & 3 deletions crates/ruff_python_ast/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4817,7 +4817,7 @@ pub enum AnyNodeRef<'a> {
ElifElseClause(&'a ast::ElifElseClause),
}

impl AnyNodeRef<'_> {
impl<'a> AnyNodeRef<'a> {
pub fn as_ptr(&self) -> NonNull<()> {
match self {
AnyNodeRef::ModModule(node) => NonNull::from(*node).cast(),
Expand Down Expand Up @@ -5456,9 +5456,9 @@ impl AnyNodeRef<'_> {
)
}

pub fn visit_preorder<'a, V>(&'a self, visitor: &mut V)
pub fn visit_preorder<'b, V>(&'b self, visitor: &mut V)
where
V: PreorderVisitor<'a> + ?Sized,
V: PreorderVisitor<'b> + ?Sized,
{
match self {
AnyNodeRef::ModModule(node) => node.visit_preorder(visitor),
Expand Down Expand Up @@ -5544,6 +5544,66 @@ impl AnyNodeRef<'_> {
AnyNodeRef::ElifElseClause(node) => node.visit_preorder(visitor),
}
}

/// The last child of the last branch, if the node has multiple branches.
pub fn last_child_in_body(&self) -> Option<AnyNodeRef<'a>> {
let body = match self {
AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. })
| AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. })
| AnyNodeRef::StmtWith(ast::StmtWith { body, .. })
| AnyNodeRef::MatchCase(MatchCase { body, .. })
| AnyNodeRef::ExceptHandlerExceptHandler(ast::ExceptHandlerExceptHandler {
body,
..
})
| AnyNodeRef::ElifElseClause(ast::ElifElseClause { body, .. }) => body,
AnyNodeRef::StmtIf(ast::StmtIf {
body,
elif_else_clauses,
..
}) => elif_else_clauses.last().map_or(body, |clause| &clause.body),

AnyNodeRef::StmtFor(ast::StmtFor { body, orelse, .. })
| AnyNodeRef::StmtWhile(ast::StmtWhile { body, orelse, .. }) => {
if orelse.is_empty() {
body
} else {
orelse
}
}

AnyNodeRef::StmtMatch(ast::StmtMatch { cases, .. }) => {
return cases.last().map(AnyNodeRef::from);
}

AnyNodeRef::StmtTry(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
..
}) => {
if finalbody.is_empty() {
if orelse.is_empty() {
if handlers.is_empty() {
body
} else {
return handlers.last().map(AnyNodeRef::from);
}
} else {
orelse
}
} else {
finalbody
}
}

// Not a node that contains an indented child node.
_ => return None,
};

body.last().map(AnyNodeRef::from)
}
}

impl<'a> From<&'a ast::ModModule> for AnyNodeRef<'a> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
###
# Blank lines around functions
###
import sys

x = 1

Expand Down Expand Up @@ -159,3 +160,97 @@ def f():
# comment

x = 1


def f():
if True:

def double(s):
return s + s
print("below function")
if True:

class A:
x = 1
print("below class")
if True:

def double(s):
return s + s
#
print("below comment function")
if True:

class A:
x = 1
#
print("below comment class")
if True:

def double(s):
return s + s
#
print("below comment function 2")
if True:

def double(s):
return s + s
#
def outer():
def inner():
pass
print("below nested functions")

if True:

def double(s):
return s + s
print("below function")
if True:

class A:
x = 1
print("below class")
def outer():
def inner():
pass
print("below nested functions")


class Path:
if sys.version_info >= (3, 11):
def joinpath(self): ...

# The .open method comes from pathlib.pyi and should be kept in sync.
@overload
def open(self): ...




def fakehttp():

class FakeHTTPConnection:
if mock_close:
def close(self):
pass
FakeHTTPConnection.fakedata = fakedata





if True:
if False:
def x():
def y():
pass
#comment
print()


# NOTE: Please keep this the last block in this file. This tests that we don't insert
# empty line(s) at the end of the file due to nested function
if True:
def nested_trailing_function():
pass
67 changes: 4 additions & 63 deletions crates/ruff_python_formatter/src/comments/placement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,9 @@ fn handle_end_of_line_comment_around_body<'a>(
// ```
// The first earlier branch filters out ambiguities e.g. around try-except-finally.
if let Some(preceding) = comment.preceding_node() {
if let Some(last_child) = last_child_in_body(preceding) {
if let Some(last_child) = preceding.last_child_in_body() {
let innermost_child =
std::iter::successors(Some(last_child), |parent| last_child_in_body(*parent))
std::iter::successors(Some(last_child), AnyNodeRef::last_child_in_body)
.last()
.unwrap_or(last_child);
return CommentPlacement::trailing(innermost_child, comment);
Expand Down Expand Up @@ -670,7 +670,7 @@ fn handle_own_line_comment_after_branch<'a>(
preceding: AnyNodeRef<'a>,
locator: &Locator,
) -> CommentPlacement<'a> {
let Some(last_child) = last_child_in_body(preceding) else {
let Some(last_child) = preceding.last_child_in_body() else {
return CommentPlacement::Default(comment);
};

Expand Down Expand Up @@ -734,7 +734,7 @@ fn handle_own_line_comment_after_branch<'a>(
return CommentPlacement::trailing(last_child_in_parent, comment);
}
Ordering::Greater => {
if let Some(nested_child) = last_child_in_body(last_child_in_parent) {
if let Some(nested_child) = last_child_in_parent.last_child_in_body() {
// The comment belongs to the inner block.
parent = Some(last_child_in_parent);
last_child_in_parent = nested_child;
Expand Down Expand Up @@ -2176,65 +2176,6 @@ where
right.is_some_and(|right| left.ptr_eq(right.into()))
}

/// The last child of the last branch, if the node has multiple branches.
fn last_child_in_body(node: AnyNodeRef) -> Option<AnyNodeRef> {
let body = match node {
AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. })
| AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. })
| AnyNodeRef::StmtWith(ast::StmtWith { body, .. })
| AnyNodeRef::MatchCase(MatchCase { body, .. })
| AnyNodeRef::ExceptHandlerExceptHandler(ast::ExceptHandlerExceptHandler {
body, ..
})
| AnyNodeRef::ElifElseClause(ast::ElifElseClause { body, .. }) => body,
AnyNodeRef::StmtIf(ast::StmtIf {
body,
elif_else_clauses,
..
}) => elif_else_clauses.last().map_or(body, |clause| &clause.body),

AnyNodeRef::StmtFor(ast::StmtFor { body, orelse, .. })
| AnyNodeRef::StmtWhile(ast::StmtWhile { body, orelse, .. }) => {
if orelse.is_empty() {
body
} else {
orelse
}
}

AnyNodeRef::StmtMatch(ast::StmtMatch { cases, .. }) => {
return cases.last().map(AnyNodeRef::from);
}

AnyNodeRef::StmtTry(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
..
}) => {
if finalbody.is_empty() {
if orelse.is_empty() {
if handlers.is_empty() {
body
} else {
return handlers.last().map(AnyNodeRef::from);
}
} else {
orelse
}
} else {
finalbody
}
}

// Not a node that contains an indented child node.
_ => return None,
};

body.last().map(AnyNodeRef::from)
}

/// Returns `true` if `statement` is the first statement in an alternate `body` (e.g. the else of an if statement)
fn is_first_statement_in_alternate_body(statement: AnyNodeRef, has_body: AnyNodeRef) -> bool {
match has_body {
Expand Down
60 changes: 56 additions & 4 deletions crates/ruff_python_formatter/src/statement/suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,65 @@ impl FormatRule<Suite, PyFormatContext<'_>> for FormatSuite {
while let Some(following) = iter.next() {
let following_comments = comments.leading_dangling_trailing(following);

let needs_empty_lines = if is_class_or_function_definition(following) {
// Here we insert empty lines even if the preceding has a trailing own line comment
true
} else {
// Find nested class or function definitions that need an empty line after them.
//
// ```python
// def f():
// if True:
//
// def double(s):
// return s + s
//
// print("below function")
// ```
std::iter::successors(
Some(AnyNodeRef::from(preceding)),
AnyNodeRef::last_child_in_body,
)
.take_while(|last_child|
// If there is a comment between preceding and following the empty lines were
// inserted before the comment by preceding and there are no extra empty lines
// after the comment.
// ```python
// class Test:
// def a(self):
// pass
// # trailing comment
//
//
// # two lines before, one line after
//
// c = 30
// ````
// This also includes nested class/function definitions, so we stop recursing
// once we see a node with a trailing own line comment:
// ```python
// def f():
// if True:
//
// def double(s):
// return s + s
//
// # nested trailing own line comment
// print("below function with trailing own line comment")
// ```
!comments.has_trailing_own_line(*last_child))
.any(|last_child| {
matches!(
last_child,
AnyNodeRef::StmtFunctionDef(_) | AnyNodeRef::StmtClassDef(_)
)
})
};

// Add empty lines before and after a function or class definition. If the preceding
// node is a function or class, and contains trailing comments, then the statement
// itself will add the requisite empty lines when formatting its comments.
if (is_class_or_function_definition(preceding)
&& !preceding_comments.has_trailing_own_line())
|| is_class_or_function_definition(following)
{
if needs_empty_lines {
if source_type.is_stub() {
stub_file_empty_lines(
self.kind,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,14 @@ with hmm_but_this_should_get_two_preceding_newlines():
elif os.name == "nt":
try:
import msvcrt
@@ -54,12 +53,10 @@
@@ -54,7 +53,6 @@
class IHopeYouAreHavingALovelyDay:
def __call__(self):
print("i_should_be_followed_by_only_one_newline")
-
else:
def foo():
pass
-
with hmm_but_this_should_get_two_preceding_newlines():
pass
```

## Ruff Output
Expand Down Expand Up @@ -151,6 +146,7 @@ else:
def foo():
pass
with hmm_but_this_should_get_two_preceding_newlines():
pass
```
Expand Down
Loading

0 comments on commit 0c3123e

Please sign in to comment.