diff --git a/crates/ruff_python_ast/src/expression.rs b/crates/ruff_python_ast/src/expression.rs new file mode 100644 index 00000000000000..1f3d72f46831d3 --- /dev/null +++ b/crates/ruff_python_ast/src/expression.rs @@ -0,0 +1,564 @@ +use crate::node::AnyNodeRef; +use crate::visitor::preorder::PreorderVisitor; +use ruff_text_size::TextRange; +use rustpython_ast::{Expr, Ranged}; +use rustpython_parser::ast; + +/// Unowned pendant to [`ast::Expr`] that stores a reference instead of a owned value. +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum ExpressionRef<'a> { + BoolOp(&'a ast::ExprBoolOp), + NamedExpr(&'a ast::ExprNamedExpr), + BinOp(&'a ast::ExprBinOp), + UnaryOp(&'a ast::ExprUnaryOp), + Lambda(&'a ast::ExprLambda), + IfExp(&'a ast::ExprIfExp), + Dict(&'a ast::ExprDict), + Set(&'a ast::ExprSet), + ListComp(&'a ast::ExprListComp), + SetComp(&'a ast::ExprSetComp), + DictComp(&'a ast::ExprDictComp), + GeneratorExp(&'a ast::ExprGeneratorExp), + Await(&'a ast::ExprAwait), + Yield(&'a ast::ExprYield), + YieldFrom(&'a ast::ExprYieldFrom), + Compare(&'a ast::ExprCompare), + Call(&'a ast::ExprCall), + FormattedValue(&'a ast::ExprFormattedValue), + JoinedStr(&'a ast::ExprJoinedStr), + Constant(&'a ast::ExprConstant), + Attribute(&'a ast::ExprAttribute), + Subscript(&'a ast::ExprSubscript), + Starred(&'a ast::ExprStarred), + Name(&'a ast::ExprName), + List(&'a ast::ExprList), + Tuple(&'a ast::ExprTuple), + Slice(&'a ast::ExprSlice), +} + +impl<'a> From<&'a Box> for ExpressionRef<'a> { + fn from(value: &'a Box) -> Self { + ExpressionRef::from(value.as_ref()) + } +} + +impl<'a> From<&'a Expr> for ExpressionRef<'a> { + fn from(value: &'a Expr) -> Self { + match value { + Expr::BoolOp(value) => ExpressionRef::BoolOp(value), + Expr::NamedExpr(value) => ExpressionRef::NamedExpr(value), + Expr::BinOp(value) => ExpressionRef::BinOp(value), + Expr::UnaryOp(value) => ExpressionRef::UnaryOp(value), + Expr::Lambda(value) => ExpressionRef::Lambda(value), + Expr::IfExp(value) => ExpressionRef::IfExp(value), + Expr::Dict(value) => ExpressionRef::Dict(value), + Expr::Set(value) => ExpressionRef::Set(value), + Expr::ListComp(value) => ExpressionRef::ListComp(value), + Expr::SetComp(value) => ExpressionRef::SetComp(value), + Expr::DictComp(value) => ExpressionRef::DictComp(value), + Expr::GeneratorExp(value) => ExpressionRef::GeneratorExp(value), + Expr::Await(value) => ExpressionRef::Await(value), + Expr::Yield(value) => ExpressionRef::Yield(value), + Expr::YieldFrom(value) => ExpressionRef::YieldFrom(value), + Expr::Compare(value) => ExpressionRef::Compare(value), + Expr::Call(value) => ExpressionRef::Call(value), + Expr::FormattedValue(value) => ExpressionRef::FormattedValue(value), + Expr::JoinedStr(value) => ExpressionRef::JoinedStr(value), + Expr::Constant(value) => ExpressionRef::Constant(value), + Expr::Attribute(value) => ExpressionRef::Attribute(value), + Expr::Subscript(value) => ExpressionRef::Subscript(value), + Expr::Starred(value) => ExpressionRef::Starred(value), + Expr::Name(value) => ExpressionRef::Name(value), + Expr::List(value) => ExpressionRef::List(value), + Expr::Tuple(value) => ExpressionRef::Tuple(value), + Expr::Slice(value) => ExpressionRef::Slice(value), + } + } +} + +impl<'a> From<&'a ast::ExprBoolOp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprBoolOp) -> Self { + Self::BoolOp(value) + } +} +impl<'a> From<&'a ast::ExprNamedExpr> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprNamedExpr) -> Self { + Self::NamedExpr(value) + } +} +impl<'a> From<&'a ast::ExprBinOp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprBinOp) -> Self { + Self::BinOp(value) + } +} +impl<'a> From<&'a ast::ExprUnaryOp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprUnaryOp) -> Self { + Self::UnaryOp(value) + } +} +impl<'a> From<&'a ast::ExprLambda> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprLambda) -> Self { + Self::Lambda(value) + } +} +impl<'a> From<&'a ast::ExprIfExp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprIfExp) -> Self { + Self::IfExp(value) + } +} +impl<'a> From<&'a ast::ExprDict> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprDict) -> Self { + Self::Dict(value) + } +} +impl<'a> From<&'a ast::ExprSet> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprSet) -> Self { + Self::Set(value) + } +} +impl<'a> From<&'a ast::ExprListComp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprListComp) -> Self { + Self::ListComp(value) + } +} +impl<'a> From<&'a ast::ExprSetComp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprSetComp) -> Self { + Self::SetComp(value) + } +} +impl<'a> From<&'a ast::ExprDictComp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprDictComp) -> Self { + Self::DictComp(value) + } +} +impl<'a> From<&'a ast::ExprGeneratorExp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprGeneratorExp) -> Self { + Self::GeneratorExp(value) + } +} +impl<'a> From<&'a ast::ExprAwait> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprAwait) -> Self { + Self::Await(value) + } +} +impl<'a> From<&'a ast::ExprYield> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprYield) -> Self { + Self::Yield(value) + } +} +impl<'a> From<&'a ast::ExprYieldFrom> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprYieldFrom) -> Self { + Self::YieldFrom(value) + } +} +impl<'a> From<&'a ast::ExprCompare> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprCompare) -> Self { + Self::Compare(value) + } +} +impl<'a> From<&'a ast::ExprCall> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprCall) -> Self { + Self::Call(value) + } +} +impl<'a> From<&'a ast::ExprFormattedValue> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprFormattedValue) -> Self { + Self::FormattedValue(value) + } +} +impl<'a> From<&'a ast::ExprJoinedStr> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprJoinedStr) -> Self { + Self::JoinedStr(value) + } +} +impl<'a> From<&'a ast::ExprConstant> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprConstant) -> Self { + Self::Constant(value) + } +} +impl<'a> From<&'a ast::ExprAttribute> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprAttribute) -> Self { + Self::Attribute(value) + } +} +impl<'a> From<&'a ast::ExprSubscript> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprSubscript) -> Self { + Self::Subscript(value) + } +} +impl<'a> From<&'a ast::ExprStarred> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprStarred) -> Self { + Self::Starred(value) + } +} +impl<'a> From<&'a ast::ExprName> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprName) -> Self { + Self::Name(value) + } +} +impl<'a> From<&'a ast::ExprList> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprList) -> Self { + Self::List(value) + } +} +impl<'a> From<&'a ast::ExprTuple> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprTuple) -> Self { + Self::Tuple(value) + } +} +impl<'a> From<&'a ast::ExprSlice> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprSlice) -> Self { + Self::Slice(value) + } +} + +impl<'a> From> for AnyNodeRef<'a> { + fn from(value: ExpressionRef<'a>) -> Self { + match value { + ExpressionRef::BoolOp(expression) => AnyNodeRef::ExprBoolOp(expression), + ExpressionRef::NamedExpr(expression) => AnyNodeRef::ExprNamedExpr(expression), + ExpressionRef::BinOp(expression) => AnyNodeRef::ExprBinOp(expression), + ExpressionRef::UnaryOp(expression) => AnyNodeRef::ExprUnaryOp(expression), + ExpressionRef::Lambda(expression) => AnyNodeRef::ExprLambda(expression), + ExpressionRef::IfExp(expression) => AnyNodeRef::ExprIfExp(expression), + ExpressionRef::Dict(expression) => AnyNodeRef::ExprDict(expression), + ExpressionRef::Set(expression) => AnyNodeRef::ExprSet(expression), + ExpressionRef::ListComp(expression) => AnyNodeRef::ExprListComp(expression), + ExpressionRef::SetComp(expression) => AnyNodeRef::ExprSetComp(expression), + ExpressionRef::DictComp(expression) => AnyNodeRef::ExprDictComp(expression), + ExpressionRef::GeneratorExp(expression) => AnyNodeRef::ExprGeneratorExp(expression), + ExpressionRef::Await(expression) => AnyNodeRef::ExprAwait(expression), + ExpressionRef::Yield(expression) => AnyNodeRef::ExprYield(expression), + ExpressionRef::YieldFrom(expression) => AnyNodeRef::ExprYieldFrom(expression), + ExpressionRef::Compare(expression) => AnyNodeRef::ExprCompare(expression), + ExpressionRef::Call(expression) => AnyNodeRef::ExprCall(expression), + ExpressionRef::FormattedValue(expression) => AnyNodeRef::ExprFormattedValue(expression), + ExpressionRef::JoinedStr(expression) => AnyNodeRef::ExprJoinedStr(expression), + ExpressionRef::Constant(expression) => AnyNodeRef::ExprConstant(expression), + ExpressionRef::Attribute(expression) => AnyNodeRef::ExprAttribute(expression), + ExpressionRef::Subscript(expression) => AnyNodeRef::ExprSubscript(expression), + ExpressionRef::Starred(expression) => AnyNodeRef::ExprStarred(expression), + ExpressionRef::Name(expression) => AnyNodeRef::ExprName(expression), + ExpressionRef::List(expression) => AnyNodeRef::ExprList(expression), + ExpressionRef::Tuple(expression) => AnyNodeRef::ExprTuple(expression), + ExpressionRef::Slice(expression) => AnyNodeRef::ExprSlice(expression), + } + } +} + +impl Ranged for ExpressionRef<'_> { + fn range(&self) -> TextRange { + match self { + ExpressionRef::BoolOp(expression) => expression.range(), + ExpressionRef::NamedExpr(expression) => expression.range(), + ExpressionRef::BinOp(expression) => expression.range(), + ExpressionRef::UnaryOp(expression) => expression.range(), + ExpressionRef::Lambda(expression) => expression.range(), + ExpressionRef::IfExp(expression) => expression.range(), + ExpressionRef::Dict(expression) => expression.range(), + ExpressionRef::Set(expression) => expression.range(), + ExpressionRef::ListComp(expression) => expression.range(), + ExpressionRef::SetComp(expression) => expression.range(), + ExpressionRef::DictComp(expression) => expression.range(), + ExpressionRef::GeneratorExp(expression) => expression.range(), + ExpressionRef::Await(expression) => expression.range(), + ExpressionRef::Yield(expression) => expression.range(), + ExpressionRef::YieldFrom(expression) => expression.range(), + ExpressionRef::Compare(expression) => expression.range(), + ExpressionRef::Call(expression) => expression.range(), + ExpressionRef::FormattedValue(expression) => expression.range(), + ExpressionRef::JoinedStr(expression) => expression.range(), + ExpressionRef::Constant(expression) => expression.range(), + ExpressionRef::Attribute(expression) => expression.range(), + ExpressionRef::Subscript(expression) => expression.range(), + ExpressionRef::Starred(expression) => expression.range(), + ExpressionRef::Name(expression) => expression.range(), + ExpressionRef::List(expression) => expression.range(), + ExpressionRef::Tuple(expression) => expression.range(), + ExpressionRef::Slice(expression) => expression.range(), + } + } +} + +pub fn walk_expression_ref<'a, V>(visitor: &mut V, expr: ExpressionRef<'a>) +where + V: PreorderVisitor<'a> + ?Sized, +{ + match expr { + ExpressionRef::BoolOp(ast::ExprBoolOp { + op, + values, + range: _range, + }) => match values.as_slice() { + [left, rest @ ..] => { + visitor.visit_expr(left); + visitor.visit_bool_op(op); + for expr in rest { + visitor.visit_expr(expr); + } + } + [] => { + visitor.visit_bool_op(op); + } + }, + + ExpressionRef::NamedExpr(ast::ExprNamedExpr { + target, + value, + range: _range, + }) => { + visitor.visit_expr(target); + visitor.visit_expr(value); + } + + ExpressionRef::BinOp(ast::ExprBinOp { + left, + op, + right, + range: _range, + }) => { + visitor.visit_expr(left); + visitor.visit_operator(op); + visitor.visit_expr(right); + } + + ExpressionRef::UnaryOp(ast::ExprUnaryOp { + op, + operand, + range: _range, + }) => { + visitor.visit_unary_op(op); + visitor.visit_expr(operand); + } + + ExpressionRef::Lambda(ast::ExprLambda { + args, + body, + range: _range, + }) => { + visitor.visit_arguments(args); + visitor.visit_expr(body); + } + + ExpressionRef::IfExp(ast::ExprIfExp { + test, + body, + orelse, + range: _range, + }) => { + // `body if test else orelse` + visitor.visit_expr(body); + visitor.visit_expr(test); + visitor.visit_expr(orelse); + } + + ExpressionRef::Dict(ast::ExprDict { + keys, + values, + range: _range, + }) => { + for (key, value) in keys.iter().zip(values) { + if let Some(key) = key { + visitor.visit_expr(key); + } + visitor.visit_expr(value); + } + } + + ExpressionRef::Set(ast::ExprSet { + elts, + range: _range, + }) => { + for expr in elts { + visitor.visit_expr(expr); + } + } + + ExpressionRef::ListComp(ast::ExprListComp { + elt, + generators, + range: _range, + }) => { + visitor.visit_expr(elt); + for comprehension in generators { + visitor.visit_comprehension(comprehension); + } + } + + ExpressionRef::SetComp(ast::ExprSetComp { + elt, + generators, + range: _range, + }) => { + visitor.visit_expr(elt); + for comprehension in generators { + visitor.visit_comprehension(comprehension); + } + } + + ExpressionRef::DictComp(ast::ExprDictComp { + key, + value, + generators, + range: _range, + }) => { + visitor.visit_expr(key); + visitor.visit_expr(value); + + for comprehension in generators { + visitor.visit_comprehension(comprehension); + } + } + + ExpressionRef::GeneratorExp(ast::ExprGeneratorExp { + elt, + generators, + range: _range, + }) => { + visitor.visit_expr(elt); + for comprehension in generators { + visitor.visit_comprehension(comprehension); + } + } + + ExpressionRef::Await(ast::ExprAwait { + value, + range: _range, + }) + | ExpressionRef::YieldFrom(ast::ExprYieldFrom { + value, + range: _range, + }) => visitor.visit_expr(value), + + ExpressionRef::Yield(ast::ExprYield { + value, + range: _range, + }) => { + if let Some(expr) = value { + visitor.visit_expr(expr); + } + } + + ExpressionRef::Compare(ast::ExprCompare { + left, + ops, + comparators, + range: _range, + }) => { + visitor.visit_expr(left); + + for (op, comparator) in ops.iter().zip(comparators) { + visitor.visit_cmp_op(op); + visitor.visit_expr(comparator); + } + } + + ExpressionRef::Call(ast::ExprCall { + func, + args, + keywords, + range: _range, + }) => { + visitor.visit_expr(func); + for expr in args { + visitor.visit_expr(expr); + } + for keyword in keywords { + visitor.visit_keyword(keyword); + } + } + + ExpressionRef::FormattedValue(ast::ExprFormattedValue { + value, format_spec, .. + }) => { + visitor.visit_expr(value); + + if let Some(expr) = format_spec { + visitor.visit_format_spec(expr); + } + } + + ExpressionRef::JoinedStr(ast::ExprJoinedStr { + values, + range: _range, + }) => { + for expr in values { + visitor.visit_expr(expr); + } + } + + ExpressionRef::Constant(ast::ExprConstant { + value, + range: _, + kind: _, + }) => visitor.visit_constant(value), + + ExpressionRef::Attribute(ast::ExprAttribute { + value, + attr: _, + ctx: _, + range: _, + }) => { + visitor.visit_expr(value); + } + + ExpressionRef::Subscript(ast::ExprSubscript { + value, + slice, + ctx: _, + range: _range, + }) => { + visitor.visit_expr(value); + visitor.visit_expr(slice); + } + ExpressionRef::Starred(ast::ExprStarred { + value, + ctx: _, + range: _range, + }) => { + visitor.visit_expr(value); + } + + ExpressionRef::Name(ast::ExprName { + id: _, + ctx: _, + range: _, + }) => {} + + ExpressionRef::List(ast::ExprList { + elts, + ctx: _, + range: _range, + }) => { + for expr in elts { + visitor.visit_expr(expr); + } + } + ExpressionRef::Tuple(ast::ExprTuple { + elts, + ctx: _, + range: _range, + }) => { + for expr in elts { + visitor.visit_expr(expr); + } + } + + ExpressionRef::Slice(ast::ExprSlice { + lower, + upper, + step, + range: _range, + }) => { + if let Some(expr) = lower { + visitor.visit_expr(expr); + } + if let Some(expr) = upper { + visitor.visit_expr(expr); + } + if let Some(expr) = step { + visitor.visit_expr(expr); + } + } + } +} diff --git a/crates/ruff_python_ast/src/lib.rs b/crates/ruff_python_ast/src/lib.rs index 72928ec9e96dd9..b8b7d883f49755 100644 --- a/crates/ruff_python_ast/src/lib.rs +++ b/crates/ruff_python_ast/src/lib.rs @@ -3,6 +3,7 @@ pub mod call_path; pub mod cast; pub mod comparable; pub mod docstrings; +pub mod expression; pub mod function; pub mod hashable; pub mod helpers; diff --git a/crates/ruff_python_ast/src/visitor/preorder.rs b/crates/ruff_python_ast/src/visitor/preorder.rs index 4a21e82bade151..638290bc218a2b 100644 --- a/crates/ruff_python_ast/src/visitor/preorder.rs +++ b/crates/ruff_python_ast/src/visitor/preorder.rs @@ -1,3 +1,4 @@ +use crate::expression::{walk_expression_ref, ExpressionRef}; use rustpython_ast::{ArgWithDefault, Mod, TypeIgnore}; use rustpython_parser::ast::{ self, Alias, Arg, Arguments, BoolOp, CmpOp, Comprehension, Constant, Decorator, ExceptHandler, @@ -414,283 +415,7 @@ pub fn walk_expr<'a, V>(visitor: &mut V, expr: &'a Expr) where V: PreorderVisitor<'a> + ?Sized, { - match expr { - Expr::BoolOp(ast::ExprBoolOp { - op, - values, - range: _range, - }) => match values.as_slice() { - [left, rest @ ..] => { - visitor.visit_expr(left); - visitor.visit_bool_op(op); - for expr in rest { - visitor.visit_expr(expr); - } - } - [] => { - visitor.visit_bool_op(op); - } - }, - - Expr::NamedExpr(ast::ExprNamedExpr { - target, - value, - range: _range, - }) => { - visitor.visit_expr(target); - visitor.visit_expr(value); - } - - Expr::BinOp(ast::ExprBinOp { - left, - op, - right, - range: _range, - }) => { - visitor.visit_expr(left); - visitor.visit_operator(op); - visitor.visit_expr(right); - } - - Expr::UnaryOp(ast::ExprUnaryOp { - op, - operand, - range: _range, - }) => { - visitor.visit_unary_op(op); - visitor.visit_expr(operand); - } - - Expr::Lambda(ast::ExprLambda { - args, - body, - range: _range, - }) => { - visitor.visit_arguments(args); - visitor.visit_expr(body); - } - - Expr::IfExp(ast::ExprIfExp { - test, - body, - orelse, - range: _range, - }) => { - // `body if test else orelse` - visitor.visit_expr(body); - visitor.visit_expr(test); - visitor.visit_expr(orelse); - } - - Expr::Dict(ast::ExprDict { - keys, - values, - range: _range, - }) => { - for (key, value) in keys.iter().zip(values) { - if let Some(key) = key { - visitor.visit_expr(key); - } - visitor.visit_expr(value); - } - } - - Expr::Set(ast::ExprSet { - elts, - range: _range, - }) => { - for expr in elts { - visitor.visit_expr(expr); - } - } - - Expr::ListComp(ast::ExprListComp { - elt, - generators, - range: _range, - }) => { - visitor.visit_expr(elt); - for comprehension in generators { - visitor.visit_comprehension(comprehension); - } - } - - Expr::SetComp(ast::ExprSetComp { - elt, - generators, - range: _range, - }) => { - visitor.visit_expr(elt); - for comprehension in generators { - visitor.visit_comprehension(comprehension); - } - } - - Expr::DictComp(ast::ExprDictComp { - key, - value, - generators, - range: _range, - }) => { - visitor.visit_expr(key); - visitor.visit_expr(value); - - for comprehension in generators { - visitor.visit_comprehension(comprehension); - } - } - - Expr::GeneratorExp(ast::ExprGeneratorExp { - elt, - generators, - range: _range, - }) => { - visitor.visit_expr(elt); - for comprehension in generators { - visitor.visit_comprehension(comprehension); - } - } - - Expr::Await(ast::ExprAwait { - value, - range: _range, - }) - | Expr::YieldFrom(ast::ExprYieldFrom { - value, - range: _range, - }) => visitor.visit_expr(value), - - Expr::Yield(ast::ExprYield { - value, - range: _range, - }) => { - if let Some(expr) = value { - visitor.visit_expr(expr); - } - } - - Expr::Compare(ast::ExprCompare { - left, - ops, - comparators, - range: _range, - }) => { - visitor.visit_expr(left); - - for (op, comparator) in ops.iter().zip(comparators) { - visitor.visit_cmp_op(op); - visitor.visit_expr(comparator); - } - } - - Expr::Call(ast::ExprCall { - func, - args, - keywords, - range: _range, - }) => { - visitor.visit_expr(func); - for expr in args { - visitor.visit_expr(expr); - } - for keyword in keywords { - visitor.visit_keyword(keyword); - } - } - - Expr::FormattedValue(ast::ExprFormattedValue { - value, format_spec, .. - }) => { - visitor.visit_expr(value); - - if let Some(expr) = format_spec { - visitor.visit_format_spec(expr); - } - } - - Expr::JoinedStr(ast::ExprJoinedStr { - values, - range: _range, - }) => { - for expr in values { - visitor.visit_expr(expr); - } - } - - Expr::Constant(ast::ExprConstant { - value, - range: _, - kind: _, - }) => visitor.visit_constant(value), - - Expr::Attribute(ast::ExprAttribute { - value, - attr: _, - ctx: _, - range: _, - }) => { - visitor.visit_expr(value); - } - - Expr::Subscript(ast::ExprSubscript { - value, - slice, - ctx: _, - range: _range, - }) => { - visitor.visit_expr(value); - visitor.visit_expr(slice); - } - Expr::Starred(ast::ExprStarred { - value, - ctx: _, - range: _range, - }) => { - visitor.visit_expr(value); - } - - Expr::Name(ast::ExprName { - id: _, - ctx: _, - range: _, - }) => {} - - Expr::List(ast::ExprList { - elts, - ctx: _, - range: _range, - }) => { - for expr in elts { - visitor.visit_expr(expr); - } - } - Expr::Tuple(ast::ExprTuple { - elts, - ctx: _, - range: _range, - }) => { - for expr in elts { - visitor.visit_expr(expr); - } - } - - Expr::Slice(ast::ExprSlice { - lower, - upper, - step, - range: _range, - }) => { - if let Some(expr) = lower { - visitor.visit_expr(expr); - } - if let Some(expr) = upper { - visitor.visit_expr(expr); - } - if let Some(expr) = step { - visitor.visit_expr(expr); - } - } - } + walk_expression_ref(visitor, ExpressionRef::from(expr)) } pub fn walk_constant<'a, V>(visitor: &mut V, constant: &'a Constant) diff --git a/crates/ruff_python_formatter/src/expression/expr_bin_op.rs b/crates/ruff_python_formatter/src/expression/expr_bin_op.rs index f6f051066a80a6..77ed14f9331bf5 100644 --- a/crates/ruff_python_formatter/src/expression/expr_bin_op.rs +++ b/crates/ruff_python_formatter/src/expression/expr_bin_op.rs @@ -7,7 +7,7 @@ use crate::expression::Parentheses; use crate::prelude::*; use crate::FormatNodeRule; use ruff_formatter::{write, FormatOwnedWithRule, FormatRefWithRule, FormatRuleWithOptions}; -use ruff_python_ast::node::AstNode; +use ruff_python_ast::expression::ExpressionRef; use rustpython_parser::ast::{ Constant, Expr, ExprAttribute, ExprBinOp, ExprConstant, ExprUnaryOp, Operator, UnaryOp, }; @@ -36,7 +36,7 @@ impl FormatNodeRule for FormatExprBinOp { let source = f.context().source(); let binary_chain: SmallVec<[&ExprBinOp; 4]> = iter::successors(Some(item), |parent| { parent.left.as_bin_op_expr().and_then(|bin_expression| { - if is_expression_parenthesized(bin_expression.as_any_node_ref(), source) { + if is_expression_parenthesized(ExpressionRef::from(bin_expression), source) { None } else { Some(bin_expression) diff --git a/crates/ruff_python_formatter/src/expression/parentheses.rs b/crates/ruff_python_formatter/src/expression/parentheses.rs index a1e0a38a6ad549..3bcb8d7706da10 100644 --- a/crates/ruff_python_formatter/src/expression/parentheses.rs +++ b/crates/ruff_python_formatter/src/expression/parentheses.rs @@ -3,7 +3,7 @@ use crate::prelude::*; use crate::trivia::{first_non_trivia_token, first_non_trivia_token_rev, Token, TokenKind}; use ruff_formatter::prelude::tag::Condition; use ruff_formatter::{format_args, Argument, Arguments}; -use ruff_python_ast::node::AnyNodeRef; +use ruff_python_ast::expression::ExpressionRef; use rustpython_parser::ast::Ranged; pub(crate) trait NeedsParentheses { @@ -15,15 +15,10 @@ pub(crate) trait NeedsParentheses { } pub(super) fn default_expression_needs_parentheses( - node: AnyNodeRef, + node: ExpressionRef, parenthesize: Parenthesize, context: &PyFormatContext, ) -> Parentheses { - debug_assert!( - node.is_expression(), - "Should only be called for expressions" - ); - #[allow(clippy::if_same_then_else)] if parenthesize.is_always() { Parentheses::Always @@ -48,7 +43,7 @@ pub(super) fn default_expression_needs_parentheses( } } -fn can_omit_optional_parentheses(expr: AnyNodeRef, context: &PyFormatContext) -> bool { +fn can_omit_optional_parentheses(expr: ExpressionRef, context: &PyFormatContext) -> bool { if context.comments().has_leading_comments(expr) { false } else { @@ -56,6 +51,8 @@ fn can_omit_optional_parentheses(expr: AnyNodeRef, context: &PyFormatContext) -> } } +// fn has_magic_comma(expr: AnyNodeRef) + /// Configures if the expression should be parenthesized. #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] pub enum Parenthesize { @@ -113,7 +110,7 @@ pub enum Parentheses { Never, } -pub(crate) fn is_expression_parenthesized(expr: AnyNodeRef, contents: &str) -> bool { +pub(crate) fn is_expression_parenthesized(expr: ExpressionRef, contents: &str) -> bool { matches!( first_non_trivia_token(expr.end(), contents), Some(Token { diff --git a/crates/ruff_python_formatter/src/statement/stmt_expr.rs b/crates/ruff_python_formatter/src/statement/stmt_expr.rs index 4e415f77cf8e8c..753d4dbc928cbe 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_expr.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_expr.rs @@ -2,6 +2,7 @@ use crate::expression::parentheses::{is_expression_parenthesized, Parenthesize}; use crate::expression::string::StringLayout; use crate::prelude::*; use crate::FormatNodeRule; +use ruff_python_ast::expression::ExpressionRef; use rustpython_parser::ast::StmtExpr; #[derive(Default)] @@ -13,7 +14,10 @@ impl FormatNodeRule for FormatStmtExpr { if let Some(constant) = value.as_constant_expr() { if constant.value.is_str() - && !is_expression_parenthesized(value.as_ref().into(), f.context().source()) + && !is_expression_parenthesized( + ExpressionRef::from(value.as_ref()), + f.context().source(), + ) { return constant.format().with_options(StringLayout::Flat).fmt(f); }