Skip to content

Commit 7e9da40

Browse files
committed
Auto merge of #15700 - rmehri01:15694_iterator_demorgan, r=Veykril
feat: add assist for applying De Morgan's law to `Iterator::all` and `Iterator::any` This PR adds an assist for transforming expressions of the form `!iter.any(|x| predicate(x))` into `iter.all(|x| !predicate(x))` and vice versa. [IteratorDeMorgans.webm](https://github.com/rust-lang/rust-analyzer/assets/52933714/aad1a299-6620-432b-9106-aafd2a7fa9f5) Closes #15694
2 parents 36be913 + c266387 commit 7e9da40

File tree

3 files changed

+357
-2
lines changed

3 files changed

+357
-2
lines changed

crates/ide-assists/src/handlers/apply_demorgan.rs

+332-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
use std::collections::VecDeque;
22

3+
use ide_db::{
4+
assists::GroupLabel,
5+
famous_defs::FamousDefs,
6+
source_change::SourceChangeBuilder,
7+
syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
8+
};
39
use syntax::{
4-
ast::{self, AstNode, Expr::BinExpr},
10+
ast::{self, make, AstNode, Expr::BinExpr, HasArgList},
511
ted::{self, Position},
612
SyntaxKind,
713
};
@@ -89,7 +95,8 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
8995

9096
let dm_lhs = demorganed.lhs()?;
9197

92-
acc.add(
98+
acc.add_group(
99+
&GroupLabel("Apply De Morgan's law".to_string()),
93100
AssistId("apply_demorgan", AssistKind::RefactorRewrite),
94101
"Apply De Morgan's law",
95102
op_range,
@@ -143,6 +150,127 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
143150
)
144151
}
145152

153+
// Assist: apply_demorgan_iterator
154+
//
155+
// Apply https://en.wikipedia.org/wiki/De_Morgan%27s_laws[De Morgan's law] to
156+
// `Iterator::all` and `Iterator::any`.
157+
//
158+
// This transforms expressions of the form `!iter.any(|x| predicate(x))` into
159+
// `iter.all(|x| !predicate(x))` and vice versa. This also works the other way for
160+
// `Iterator::all` into `Iterator::any`.
161+
//
162+
// ```
163+
// # //- minicore: iterator
164+
// fn main() {
165+
// let arr = [1, 2, 3];
166+
// if !arr.into_iter().$0any(|num| num == 4) {
167+
// println!("foo");
168+
// }
169+
// }
170+
// ```
171+
// ->
172+
// ```
173+
// fn main() {
174+
// let arr = [1, 2, 3];
175+
// if arr.into_iter().all(|num| num != 4) {
176+
// println!("foo");
177+
// }
178+
// }
179+
// ```
180+
pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
181+
let method_call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
182+
let (name, arg_expr) = validate_method_call_expr(ctx, &method_call)?;
183+
184+
let ast::Expr::ClosureExpr(closure_expr) = arg_expr else { return None };
185+
let closure_body = closure_expr.body()?;
186+
187+
let op_range = method_call.syntax().text_range();
188+
let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str());
189+
acc.add_group(
190+
&GroupLabel("Apply De Morgan's law".to_string()),
191+
AssistId("apply_demorgan_iterator", AssistKind::RefactorRewrite),
192+
label,
193+
op_range,
194+
|edit| {
195+
// replace the method name
196+
let new_name = match name.text().as_str() {
197+
"all" => make::name_ref("any"),
198+
"any" => make::name_ref("all"),
199+
_ => unreachable!(),
200+
}
201+
.clone_for_update();
202+
edit.replace_ast(name, new_name);
203+
204+
// negate all tail expressions in the closure body
205+
let tail_cb = &mut |e: &_| tail_cb_impl(edit, e);
206+
walk_expr(&closure_body, &mut |expr| {
207+
if let ast::Expr::ReturnExpr(ret_expr) = expr {
208+
if let Some(ret_expr_arg) = &ret_expr.expr() {
209+
for_each_tail_expr(ret_expr_arg, tail_cb);
210+
}
211+
}
212+
});
213+
for_each_tail_expr(&closure_body, tail_cb);
214+
215+
// negate the whole method call
216+
if let Some(prefix_expr) = method_call
217+
.syntax()
218+
.parent()
219+
.and_then(ast::PrefixExpr::cast)
220+
.filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not)))
221+
{
222+
edit.delete(
223+
prefix_expr
224+
.op_token()
225+
.expect("prefix expression always has an operator")
226+
.text_range(),
227+
);
228+
} else {
229+
edit.insert(method_call.syntax().text_range().start(), "!");
230+
}
231+
},
232+
)
233+
}
234+
235+
/// Ensures that the method call is to `Iterator::all` or `Iterator::any`.
236+
fn validate_method_call_expr(
237+
ctx: &AssistContext<'_>,
238+
method_call: &ast::MethodCallExpr,
239+
) -> Option<(ast::NameRef, ast::Expr)> {
240+
let name_ref = method_call.name_ref()?;
241+
if name_ref.text() != "all" && name_ref.text() != "any" {
242+
return None;
243+
}
244+
let arg_expr = method_call.arg_list()?.args().next()?;
245+
246+
let sema = &ctx.sema;
247+
248+
let receiver = method_call.receiver()?;
249+
let it_type = sema.type_of_expr(&receiver)?.adjusted();
250+
let module = sema.scope(receiver.syntax())?.module();
251+
let krate = module.krate();
252+
253+
let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
254+
it_type.impls_trait(sema.db, iter_trait, &[]).then_some((name_ref, arg_expr))
255+
}
256+
257+
fn tail_cb_impl(edit: &mut SourceChangeBuilder, e: &ast::Expr) {
258+
match e {
259+
ast::Expr::BreakExpr(break_expr) => {
260+
if let Some(break_expr_arg) = break_expr.expr() {
261+
for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(edit, e))
262+
}
263+
}
264+
ast::Expr::ReturnExpr(_) => {
265+
// all return expressions have already been handled by the walk loop
266+
}
267+
e => {
268+
let inverted_body = invert_boolean_expression(e.clone());
269+
edit.replace(e.syntax().text_range(), inverted_body.syntax().text());
270+
}
271+
}
272+
}
273+
146274
#[cfg(test)]
147275
mod tests {
148276
use super::*;
@@ -255,4 +383,206 @@ fn f() { !(S <= S || S < S) }
255383
"fn() { let x = a && b && c; }",
256384
)
257385
}
386+
387+
#[test]
388+
fn demorgan_iterator_any_all_reverse() {
389+
check_assist(
390+
apply_demorgan_iterator,
391+
r#"
392+
//- minicore: iterator
393+
fn main() {
394+
let arr = [1, 2, 3];
395+
if arr.into_iter().all(|num| num $0!= 4) {
396+
println!("foo");
397+
}
398+
}
399+
"#,
400+
r#"
401+
fn main() {
402+
let arr = [1, 2, 3];
403+
if !arr.into_iter().any(|num| num == 4) {
404+
println!("foo");
405+
}
406+
}
407+
"#,
408+
);
409+
}
410+
411+
#[test]
412+
fn demorgan_iterator_all_any() {
413+
check_assist(
414+
apply_demorgan_iterator,
415+
r#"
416+
//- minicore: iterator
417+
fn main() {
418+
let arr = [1, 2, 3];
419+
if !arr.into_iter().$0all(|num| num > 3) {
420+
println!("foo");
421+
}
422+
}
423+
"#,
424+
r#"
425+
fn main() {
426+
let arr = [1, 2, 3];
427+
if arr.into_iter().any(|num| num <= 3) {
428+
println!("foo");
429+
}
430+
}
431+
"#,
432+
);
433+
}
434+
435+
#[test]
436+
fn demorgan_iterator_multiple_terms() {
437+
check_assist(
438+
apply_demorgan_iterator,
439+
r#"
440+
//- minicore: iterator
441+
fn main() {
442+
let arr = [1, 2, 3];
443+
if !arr.into_iter().$0any(|num| num > 3 && num == 23 && num <= 30) {
444+
println!("foo");
445+
}
446+
}
447+
"#,
448+
r#"
449+
fn main() {
450+
let arr = [1, 2, 3];
451+
if arr.into_iter().all(|num| !(num > 3 && num == 23 && num <= 30)) {
452+
println!("foo");
453+
}
454+
}
455+
"#,
456+
);
457+
}
458+
459+
#[test]
460+
fn demorgan_iterator_double_negation() {
461+
check_assist(
462+
apply_demorgan_iterator,
463+
r#"
464+
//- minicore: iterator
465+
fn main() {
466+
let arr = [1, 2, 3];
467+
if !arr.into_iter().$0all(|num| !(num > 3)) {
468+
println!("foo");
469+
}
470+
}
471+
"#,
472+
r#"
473+
fn main() {
474+
let arr = [1, 2, 3];
475+
if arr.into_iter().any(|num| num > 3) {
476+
println!("foo");
477+
}
478+
}
479+
"#,
480+
);
481+
}
482+
483+
#[test]
484+
fn demorgan_iterator_double_parens() {
485+
check_assist(
486+
apply_demorgan_iterator,
487+
r#"
488+
//- minicore: iterator
489+
fn main() {
490+
let arr = [1, 2, 3];
491+
if !arr.into_iter().$0any(|num| (num > 3 && (num == 1 || num == 2))) {
492+
println!("foo");
493+
}
494+
}
495+
"#,
496+
r#"
497+
fn main() {
498+
let arr = [1, 2, 3];
499+
if arr.into_iter().all(|num| !(num > 3 && (num == 1 || num == 2))) {
500+
println!("foo");
501+
}
502+
}
503+
"#,
504+
);
505+
}
506+
507+
#[test]
508+
fn demorgan_iterator_multiline() {
509+
check_assist(
510+
apply_demorgan_iterator,
511+
r#"
512+
//- minicore: iterator
513+
fn main() {
514+
let arr = [1, 2, 3];
515+
if arr
516+
.into_iter()
517+
.all$0(|num| !num.is_negative())
518+
{
519+
println!("foo");
520+
}
521+
}
522+
"#,
523+
r#"
524+
fn main() {
525+
let arr = [1, 2, 3];
526+
if !arr
527+
.into_iter()
528+
.any(|num| num.is_negative())
529+
{
530+
println!("foo");
531+
}
532+
}
533+
"#,
534+
);
535+
}
536+
537+
#[test]
538+
fn demorgan_iterator_block_closure() {
539+
check_assist(
540+
apply_demorgan_iterator,
541+
r#"
542+
//- minicore: iterator
543+
fn main() {
544+
let arr = [-1, 1, 2, 3];
545+
if arr.into_iter().all(|num: i32| {
546+
$0if num.is_positive() {
547+
num <= 3
548+
} else {
549+
num >= -1
550+
}
551+
}) {
552+
println!("foo");
553+
}
554+
}
555+
"#,
556+
r#"
557+
fn main() {
558+
let arr = [-1, 1, 2, 3];
559+
if !arr.into_iter().any(|num: i32| {
560+
if num.is_positive() {
561+
num > 3
562+
} else {
563+
num < -1
564+
}
565+
}) {
566+
println!("foo");
567+
}
568+
}
569+
"#,
570+
);
571+
}
572+
573+
#[test]
574+
fn demorgan_iterator_wrong_method() {
575+
check_assist_not_applicable(
576+
apply_demorgan_iterator,
577+
r#"
578+
//- minicore: iterator
579+
fn main() {
580+
let arr = [1, 2, 3];
581+
if !arr.into_iter().$0map(|num| num > 3) {
582+
println!("foo");
583+
}
584+
}
585+
"#,
586+
);
587+
}
258588
}

crates/ide-assists/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ mod handlers {
226226
add_return_type::add_return_type,
227227
add_turbo_fish::add_turbo_fish,
228228
apply_demorgan::apply_demorgan,
229+
apply_demorgan::apply_demorgan_iterator,
229230
auto_import::auto_import,
230231
bind_unused_param::bind_unused_param,
231232
bool_to_enum::bool_to_enum,

crates/ide-assists/src/tests/generated.rs

+24
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,30 @@ fn main() {
244244
)
245245
}
246246

247+
#[test]
248+
fn doctest_apply_demorgan_iterator() {
249+
check_doc_test(
250+
"apply_demorgan_iterator",
251+
r#####"
252+
//- minicore: iterator
253+
fn main() {
254+
let arr = [1, 2, 3];
255+
if !arr.into_iter().$0any(|num| num == 4) {
256+
println!("foo");
257+
}
258+
}
259+
"#####,
260+
r#####"
261+
fn main() {
262+
let arr = [1, 2, 3];
263+
if arr.into_iter().all(|num| num != 4) {
264+
println!("foo");
265+
}
266+
}
267+
"#####,
268+
)
269+
}
270+
247271
#[test]
248272
fn doctest_auto_import() {
249273
check_doc_test(

0 commit comments

Comments
 (0)