|
1 | 1 | use std::collections::VecDeque;
|
2 | 2 |
|
| 3 | +use ide_db::{ |
| 4 | + assists::GroupLabel, |
| 5 | + famous_defs::FamousDefs, |
| 6 | + source_change::SourceChangeBuilder, |
| 7 | + syntax_helpers::node_ext::{for_each_tail_expr, walk_expr}, |
| 8 | +}; |
3 | 9 | use syntax::{
|
4 |
| - ast::{self, AstNode, Expr::BinExpr}, |
| 10 | + ast::{self, make, AstNode, Expr::BinExpr, HasArgList}, |
5 | 11 | ted::{self, Position},
|
6 | 12 | SyntaxKind,
|
7 | 13 | };
|
@@ -89,7 +95,8 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
|
89 | 95 |
|
90 | 96 | let dm_lhs = demorganed.lhs()?;
|
91 | 97 |
|
92 |
| - acc.add( |
| 98 | + acc.add_group( |
| 99 | + &GroupLabel("Apply De Morgan's law".to_string()), |
93 | 100 | AssistId("apply_demorgan", AssistKind::RefactorRewrite),
|
94 | 101 | "Apply De Morgan's law",
|
95 | 102 | op_range,
|
@@ -143,6 +150,127 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
|
143 | 150 | )
|
144 | 151 | }
|
145 | 152 |
|
| 153 | +// Assist: apply_demorgan_iterator |
| 154 | +// |
| 155 | +// Apply https://en.wikipedia.org/wiki/De_Morgan%27s_laws[De Morgan's law] to |
| 156 | +// `Iterator::all` and `Iterator::any`. |
| 157 | +// |
| 158 | +// This transforms expressions of the form `!iter.any(|x| predicate(x))` into |
| 159 | +// `iter.all(|x| !predicate(x))` and vice versa. This also works the other way for |
| 160 | +// `Iterator::all` into `Iterator::any`. |
| 161 | +// |
| 162 | +// ``` |
| 163 | +// # //- minicore: iterator |
| 164 | +// fn main() { |
| 165 | +// let arr = [1, 2, 3]; |
| 166 | +// if !arr.into_iter().$0any(|num| num == 4) { |
| 167 | +// println!("foo"); |
| 168 | +// } |
| 169 | +// } |
| 170 | +// ``` |
| 171 | +// -> |
| 172 | +// ``` |
| 173 | +// fn main() { |
| 174 | +// let arr = [1, 2, 3]; |
| 175 | +// if arr.into_iter().all(|num| num != 4) { |
| 176 | +// println!("foo"); |
| 177 | +// } |
| 178 | +// } |
| 179 | +// ``` |
| 180 | +pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { |
| 181 | + let method_call: ast::MethodCallExpr = ctx.find_node_at_offset()?; |
| 182 | + let (name, arg_expr) = validate_method_call_expr(ctx, &method_call)?; |
| 183 | + |
| 184 | + let ast::Expr::ClosureExpr(closure_expr) = arg_expr else { return None }; |
| 185 | + let closure_body = closure_expr.body()?; |
| 186 | + |
| 187 | + let op_range = method_call.syntax().text_range(); |
| 188 | + let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str()); |
| 189 | + acc.add_group( |
| 190 | + &GroupLabel("Apply De Morgan's law".to_string()), |
| 191 | + AssistId("apply_demorgan_iterator", AssistKind::RefactorRewrite), |
| 192 | + label, |
| 193 | + op_range, |
| 194 | + |edit| { |
| 195 | + // replace the method name |
| 196 | + let new_name = match name.text().as_str() { |
| 197 | + "all" => make::name_ref("any"), |
| 198 | + "any" => make::name_ref("all"), |
| 199 | + _ => unreachable!(), |
| 200 | + } |
| 201 | + .clone_for_update(); |
| 202 | + edit.replace_ast(name, new_name); |
| 203 | + |
| 204 | + // negate all tail expressions in the closure body |
| 205 | + let tail_cb = &mut |e: &_| tail_cb_impl(edit, e); |
| 206 | + walk_expr(&closure_body, &mut |expr| { |
| 207 | + if let ast::Expr::ReturnExpr(ret_expr) = expr { |
| 208 | + if let Some(ret_expr_arg) = &ret_expr.expr() { |
| 209 | + for_each_tail_expr(ret_expr_arg, tail_cb); |
| 210 | + } |
| 211 | + } |
| 212 | + }); |
| 213 | + for_each_tail_expr(&closure_body, tail_cb); |
| 214 | + |
| 215 | + // negate the whole method call |
| 216 | + if let Some(prefix_expr) = method_call |
| 217 | + .syntax() |
| 218 | + .parent() |
| 219 | + .and_then(ast::PrefixExpr::cast) |
| 220 | + .filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not))) |
| 221 | + { |
| 222 | + edit.delete( |
| 223 | + prefix_expr |
| 224 | + .op_token() |
| 225 | + .expect("prefix expression always has an operator") |
| 226 | + .text_range(), |
| 227 | + ); |
| 228 | + } else { |
| 229 | + edit.insert(method_call.syntax().text_range().start(), "!"); |
| 230 | + } |
| 231 | + }, |
| 232 | + ) |
| 233 | +} |
| 234 | + |
| 235 | +/// Ensures that the method call is to `Iterator::all` or `Iterator::any`. |
| 236 | +fn validate_method_call_expr( |
| 237 | + ctx: &AssistContext<'_>, |
| 238 | + method_call: &ast::MethodCallExpr, |
| 239 | +) -> Option<(ast::NameRef, ast::Expr)> { |
| 240 | + let name_ref = method_call.name_ref()?; |
| 241 | + if name_ref.text() != "all" && name_ref.text() != "any" { |
| 242 | + return None; |
| 243 | + } |
| 244 | + let arg_expr = method_call.arg_list()?.args().next()?; |
| 245 | + |
| 246 | + let sema = &ctx.sema; |
| 247 | + |
| 248 | + let receiver = method_call.receiver()?; |
| 249 | + let it_type = sema.type_of_expr(&receiver)?.adjusted(); |
| 250 | + let module = sema.scope(receiver.syntax())?.module(); |
| 251 | + let krate = module.krate(); |
| 252 | + |
| 253 | + let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?; |
| 254 | + it_type.impls_trait(sema.db, iter_trait, &[]).then_some((name_ref, arg_expr)) |
| 255 | +} |
| 256 | + |
| 257 | +fn tail_cb_impl(edit: &mut SourceChangeBuilder, e: &ast::Expr) { |
| 258 | + match e { |
| 259 | + ast::Expr::BreakExpr(break_expr) => { |
| 260 | + if let Some(break_expr_arg) = break_expr.expr() { |
| 261 | + for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(edit, e)) |
| 262 | + } |
| 263 | + } |
| 264 | + ast::Expr::ReturnExpr(_) => { |
| 265 | + // all return expressions have already been handled by the walk loop |
| 266 | + } |
| 267 | + e => { |
| 268 | + let inverted_body = invert_boolean_expression(e.clone()); |
| 269 | + edit.replace(e.syntax().text_range(), inverted_body.syntax().text()); |
| 270 | + } |
| 271 | + } |
| 272 | +} |
| 273 | + |
146 | 274 | #[cfg(test)]
|
147 | 275 | mod tests {
|
148 | 276 | use super::*;
|
@@ -255,4 +383,206 @@ fn f() { !(S <= S || S < S) }
|
255 | 383 | "fn() { let x = a && b && c; }",
|
256 | 384 | )
|
257 | 385 | }
|
| 386 | + |
| 387 | + #[test] |
| 388 | + fn demorgan_iterator_any_all_reverse() { |
| 389 | + check_assist( |
| 390 | + apply_demorgan_iterator, |
| 391 | + r#" |
| 392 | +//- minicore: iterator |
| 393 | +fn main() { |
| 394 | + let arr = [1, 2, 3]; |
| 395 | + if arr.into_iter().all(|num| num $0!= 4) { |
| 396 | + println!("foo"); |
| 397 | + } |
| 398 | +} |
| 399 | +"#, |
| 400 | + r#" |
| 401 | +fn main() { |
| 402 | + let arr = [1, 2, 3]; |
| 403 | + if !arr.into_iter().any(|num| num == 4) { |
| 404 | + println!("foo"); |
| 405 | + } |
| 406 | +} |
| 407 | +"#, |
| 408 | + ); |
| 409 | + } |
| 410 | + |
| 411 | + #[test] |
| 412 | + fn demorgan_iterator_all_any() { |
| 413 | + check_assist( |
| 414 | + apply_demorgan_iterator, |
| 415 | + r#" |
| 416 | +//- minicore: iterator |
| 417 | +fn main() { |
| 418 | + let arr = [1, 2, 3]; |
| 419 | + if !arr.into_iter().$0all(|num| num > 3) { |
| 420 | + println!("foo"); |
| 421 | + } |
| 422 | +} |
| 423 | +"#, |
| 424 | + r#" |
| 425 | +fn main() { |
| 426 | + let arr = [1, 2, 3]; |
| 427 | + if arr.into_iter().any(|num| num <= 3) { |
| 428 | + println!("foo"); |
| 429 | + } |
| 430 | +} |
| 431 | +"#, |
| 432 | + ); |
| 433 | + } |
| 434 | + |
| 435 | + #[test] |
| 436 | + fn demorgan_iterator_multiple_terms() { |
| 437 | + check_assist( |
| 438 | + apply_demorgan_iterator, |
| 439 | + r#" |
| 440 | +//- minicore: iterator |
| 441 | +fn main() { |
| 442 | + let arr = [1, 2, 3]; |
| 443 | + if !arr.into_iter().$0any(|num| num > 3 && num == 23 && num <= 30) { |
| 444 | + println!("foo"); |
| 445 | + } |
| 446 | +} |
| 447 | +"#, |
| 448 | + r#" |
| 449 | +fn main() { |
| 450 | + let arr = [1, 2, 3]; |
| 451 | + if arr.into_iter().all(|num| !(num > 3 && num == 23 && num <= 30)) { |
| 452 | + println!("foo"); |
| 453 | + } |
| 454 | +} |
| 455 | +"#, |
| 456 | + ); |
| 457 | + } |
| 458 | + |
| 459 | + #[test] |
| 460 | + fn demorgan_iterator_double_negation() { |
| 461 | + check_assist( |
| 462 | + apply_demorgan_iterator, |
| 463 | + r#" |
| 464 | +//- minicore: iterator |
| 465 | +fn main() { |
| 466 | + let arr = [1, 2, 3]; |
| 467 | + if !arr.into_iter().$0all(|num| !(num > 3)) { |
| 468 | + println!("foo"); |
| 469 | + } |
| 470 | +} |
| 471 | +"#, |
| 472 | + r#" |
| 473 | +fn main() { |
| 474 | + let arr = [1, 2, 3]; |
| 475 | + if arr.into_iter().any(|num| num > 3) { |
| 476 | + println!("foo"); |
| 477 | + } |
| 478 | +} |
| 479 | +"#, |
| 480 | + ); |
| 481 | + } |
| 482 | + |
| 483 | + #[test] |
| 484 | + fn demorgan_iterator_double_parens() { |
| 485 | + check_assist( |
| 486 | + apply_demorgan_iterator, |
| 487 | + r#" |
| 488 | +//- minicore: iterator |
| 489 | +fn main() { |
| 490 | + let arr = [1, 2, 3]; |
| 491 | + if !arr.into_iter().$0any(|num| (num > 3 && (num == 1 || num == 2))) { |
| 492 | + println!("foo"); |
| 493 | + } |
| 494 | +} |
| 495 | +"#, |
| 496 | + r#" |
| 497 | +fn main() { |
| 498 | + let arr = [1, 2, 3]; |
| 499 | + if arr.into_iter().all(|num| !(num > 3 && (num == 1 || num == 2))) { |
| 500 | + println!("foo"); |
| 501 | + } |
| 502 | +} |
| 503 | +"#, |
| 504 | + ); |
| 505 | + } |
| 506 | + |
| 507 | + #[test] |
| 508 | + fn demorgan_iterator_multiline() { |
| 509 | + check_assist( |
| 510 | + apply_demorgan_iterator, |
| 511 | + r#" |
| 512 | +//- minicore: iterator |
| 513 | +fn main() { |
| 514 | + let arr = [1, 2, 3]; |
| 515 | + if arr |
| 516 | + .into_iter() |
| 517 | + .all$0(|num| !num.is_negative()) |
| 518 | + { |
| 519 | + println!("foo"); |
| 520 | + } |
| 521 | +} |
| 522 | +"#, |
| 523 | + r#" |
| 524 | +fn main() { |
| 525 | + let arr = [1, 2, 3]; |
| 526 | + if !arr |
| 527 | + .into_iter() |
| 528 | + .any(|num| num.is_negative()) |
| 529 | + { |
| 530 | + println!("foo"); |
| 531 | + } |
| 532 | +} |
| 533 | +"#, |
| 534 | + ); |
| 535 | + } |
| 536 | + |
| 537 | + #[test] |
| 538 | + fn demorgan_iterator_block_closure() { |
| 539 | + check_assist( |
| 540 | + apply_demorgan_iterator, |
| 541 | + r#" |
| 542 | +//- minicore: iterator |
| 543 | +fn main() { |
| 544 | + let arr = [-1, 1, 2, 3]; |
| 545 | + if arr.into_iter().all(|num: i32| { |
| 546 | + $0if num.is_positive() { |
| 547 | + num <= 3 |
| 548 | + } else { |
| 549 | + num >= -1 |
| 550 | + } |
| 551 | + }) { |
| 552 | + println!("foo"); |
| 553 | + } |
| 554 | +} |
| 555 | +"#, |
| 556 | + r#" |
| 557 | +fn main() { |
| 558 | + let arr = [-1, 1, 2, 3]; |
| 559 | + if !arr.into_iter().any(|num: i32| { |
| 560 | + if num.is_positive() { |
| 561 | + num > 3 |
| 562 | + } else { |
| 563 | + num < -1 |
| 564 | + } |
| 565 | + }) { |
| 566 | + println!("foo"); |
| 567 | + } |
| 568 | +} |
| 569 | +"#, |
| 570 | + ); |
| 571 | + } |
| 572 | + |
| 573 | + #[test] |
| 574 | + fn demorgan_iterator_wrong_method() { |
| 575 | + check_assist_not_applicable( |
| 576 | + apply_demorgan_iterator, |
| 577 | + r#" |
| 578 | +//- minicore: iterator |
| 579 | +fn main() { |
| 580 | + let arr = [1, 2, 3]; |
| 581 | + if !arr.into_iter().$0map(|num| num > 3) { |
| 582 | + println!("foo"); |
| 583 | + } |
| 584 | +} |
| 585 | +"#, |
| 586 | + ); |
| 587 | + } |
258 | 588 | }
|
0 commit comments