Skip to content

Commit

Permalink
Code action: Expand catch all variant (rescript-lang#987)
Browse files Browse the repository at this point in the history
* code action for expanding catch all with variants

* make work with polyvariants

* extend to work on options

* changelog + fix
  • Loading branch information
zth authored and jfrolich committed Sep 3, 2024
1 parent ada97ac commit affe4ac
Show file tree
Hide file tree
Showing 5 changed files with 340 additions and 48 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
- Emit `%todo` instead of `failwith("TODO")` when we can (ReScript >= v11.1). https://github.com/rescript-lang/rescript-vscode/pull/981
- Complete `%todo`. https://github.com/rescript-lang/rescript-vscode/pull/981
- Add code action for extracting a locally defined module into its own file. https://github.com/rescript-lang/rescript-vscode/pull/983
- Add code action for expanding catch-all patterns. https://github.com/rescript-lang/rescript-vscode/pull/987

## 1.50.0

Expand Down
4 changes: 3 additions & 1 deletion analysis/src/CompletionFrontEnd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,9 @@ let completionWithParser1 ~currentFile ~debug ~offset ~path ~posCursor
typedCompletionExpr expr;
match expr.pexp_desc with
| Pexp_match (expr, cases)
when cases <> [] && locHasCursor expr.pexp_loc = false ->
when cases <> []
&& locHasCursor expr.pexp_loc = false
&& Option.is_none findThisExprLoc ->
if Debug.verbose () then
print_endline "[completionFrontend] Checking each case";
let ctxPath = exprToContextPath expr in
Expand Down
272 changes: 230 additions & 42 deletions analysis/src/Xform.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,42 @@ let rangeOfLoc (loc : Location.t) =
let end_ = loc |> Loc.end_ |> mkPosition in
{Protocol.start; end_}

let extractTypeFromExpr expr ~debug ~path ~currentFile ~full ~pos =
match
expr.Parsetree.pexp_loc
|> CompletionFrontEnd.findTypeOfExpressionAtLoc ~debug ~path ~currentFile
~posCursor:(Pos.ofLexing expr.Parsetree.pexp_loc.loc_start)
with
| Some (completable, scope) -> (
let env = SharedTypes.QueryEnv.fromFile full.SharedTypes.file in
let completions =
completable
|> CompletionBackEnd.processCompletable ~debug ~full ~pos ~scope ~env
~forHover:true
in
let rawOpens = Scope.getRawOpens scope in
match completions with
| {env} :: _ -> (
let opens =
CompletionBackEnd.getOpens ~debug ~rawOpens ~package:full.package ~env
in
match
CompletionBackEnd.completionsGetCompletionType2 ~debug ~full ~rawOpens
~opens ~pos completions
with
| Some (typ, _env) ->
let extractedType =
match typ with
| ExtractedType t -> Some t
| TypeExpr t ->
TypeUtils.extractType t ~env ~package:full.package
|> TypeUtils.getExtractedType
in
extractedType
| None -> None)
| _ -> None)
| _ -> None

module IfThenElse = struct
(* Convert if-then-else to switch *)

Expand Down Expand Up @@ -324,6 +360,196 @@ module AddTypeAnnotation = struct
| _ -> ()))
end

module ExpandCatchAllForVariants = struct
let mkIterator ~pos ~result =
let expr (iterator : Ast_iterator.iterator) (e : Parsetree.expression) =
(if e.pexp_loc |> Loc.hasPos ~pos then
match e.pexp_desc with
| Pexp_match (switchExpr, cases) -> (
let catchAllCase =
cases
|> List.find_opt (fun (c : Parsetree.case) ->
match c with
| {pc_lhs = {ppat_desc = Ppat_any}} -> true
| _ -> false)
in
match catchAllCase with
| None -> ()
| Some catchAllCase ->
result := Some (switchExpr, catchAllCase, cases))
| _ -> ());
Ast_iterator.default_iterator.expr iterator e
in
{Ast_iterator.default_iterator with expr}

let xform ~path ~pos ~full ~structure ~currentFile ~codeActions ~debug =
let result = ref None in
let iterator = mkIterator ~pos ~result in
iterator.structure iterator structure;
match !result with
| None -> ()
| Some (switchExpr, catchAllCase, cases) -> (
if Debug.verbose () then
print_endline
"[codeAction - ExpandCatchAllForVariants] Found target switch";
let currentConstructorNames =
cases
|> List.filter_map (fun (c : Parsetree.case) ->
match c with
| {pc_lhs = {ppat_desc = Ppat_construct ({txt}, _)}} ->
Some (Longident.last txt)
| {pc_lhs = {ppat_desc = Ppat_variant (name, _)}} -> Some name
| _ -> None)
in
match
switchExpr
|> extractTypeFromExpr ~debug ~path ~currentFile ~full
~pos:(Pos.ofLexing switchExpr.pexp_loc.loc_end)
with
| Some (Tvariant {constructors}) ->
let missingConstructors =
constructors
|> List.filter (fun (c : SharedTypes.Constructor.t) ->
currentConstructorNames |> List.mem c.cname.txt = false)
in
if List.length missingConstructors > 0 then
let newText =
missingConstructors
|> List.map (fun (c : SharedTypes.Constructor.t) ->
c.cname.txt
^
match c.args with
| Args [] -> ""
| Args _ | InlineRecord _ -> "(_)")
|> String.concat " | "
in
let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
let codeAction =
CodeActions.make ~title:"Expand catch-all" ~kind:RefactorRewrite
~uri:path ~newText ~range
in
codeActions := codeAction :: !codeActions
else ()
| Some (Tpolyvariant {constructors}) ->
let missingConstructors =
constructors
|> List.filter (fun (c : SharedTypes.polyVariantConstructor) ->
currentConstructorNames |> List.mem c.name = false)
in
if List.length missingConstructors > 0 then
let newText =
missingConstructors
|> List.map (fun (c : SharedTypes.polyVariantConstructor) ->
Res_printer.polyVarIdentToString c.name
^
match c.args with
| [] -> ""
| _ -> "(_)")
|> String.concat " | "
in
let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
let codeAction =
CodeActions.make ~title:"Expand catch-all" ~kind:RefactorRewrite
~uri:path ~newText ~range
in
codeActions := codeAction :: !codeActions
else ()
| Some (Toption (env, innerType)) -> (
if Debug.verbose () then
print_endline
"[codeAction - ExpandCatchAllForVariants] Found option type";
let innerType =
match innerType with
| ExtractedType t -> Some t
| TypeExpr t -> (
match TypeUtils.extractType ~env ~package:full.package t with
| None -> None
| Some (t, _) -> Some t)
in
match innerType with
| Some ((Tvariant _ | Tpolyvariant _) as variant) ->
let currentConstructorNames =
cases
|> List.filter_map (fun (c : Parsetree.case) ->
match c with
| {
pc_lhs =
{
ppat_desc =
Ppat_construct
( {txt = Lident "Some"},
Some {ppat_desc = Ppat_construct ({txt}, _)} );
};
} ->
Some (Longident.last txt)
| {
pc_lhs =
{
ppat_desc =
Ppat_construct
( {txt = Lident "Some"},
Some {ppat_desc = Ppat_variant (name, _)} );
};
} ->
Some name
| _ -> None)
in
let hasNoneCase =
cases
|> List.exists (fun (c : Parsetree.case) ->
match c.pc_lhs.ppat_desc with
| Ppat_construct ({txt = Lident "None"}, _) -> true
| _ -> false)
in
let missingConstructors =
match variant with
| Tvariant {constructors} ->
constructors
|> List.filter_map (fun (c : SharedTypes.Constructor.t) ->
if currentConstructorNames |> List.mem c.cname.txt = false
then
Some
( c.cname.txt,
match c.args with
| Args [] -> false
| _ -> true )
else None)
| Tpolyvariant {constructors} ->
constructors
|> List.filter_map
(fun (c : SharedTypes.polyVariantConstructor) ->
if currentConstructorNames |> List.mem c.name = false then
Some
( Res_printer.polyVarIdentToString c.name,
match c.args with
| [] -> false
| _ -> true )
else None)
| _ -> []
in
if List.length missingConstructors > 0 || not hasNoneCase then
let newText =
"Some("
^ (missingConstructors
|> List.map (fun (name, hasArgs) ->
name ^ if hasArgs then "(_)" else "")
|> String.concat " | ")
^ ")"
in
let newText =
if hasNoneCase then newText else newText ^ " | None"
in
let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
let codeAction =
CodeActions.make ~title:"Expand catch-all" ~kind:RefactorRewrite
~uri:path ~newText ~range
in
codeActions := codeAction :: !codeActions
else ()
| _ -> ())
| _ -> ())
end

module ExhaustiveSwitch = struct
(* Expand expression to be an exhaustive switch of the underlying value *)
type posType = Single of Pos.t | Range of Pos.t * Pos.t
Expand All @@ -336,46 +562,6 @@ module ExhaustiveSwitch = struct
}
| Selection of {expr: Parsetree.expression}

module C = struct
let extractTypeFromExpr expr ~debug ~path ~currentFile ~full ~pos =
match
expr.Parsetree.pexp_loc
|> CompletionFrontEnd.findTypeOfExpressionAtLoc ~debug ~path
~currentFile
~posCursor:(Pos.ofLexing expr.Parsetree.pexp_loc.loc_start)
with
| Some (completable, scope) -> (
let env = SharedTypes.QueryEnv.fromFile full.SharedTypes.file in
let completions =
completable
|> CompletionBackEnd.processCompletable ~debug ~full ~pos ~scope ~env
~forHover:true
in
let rawOpens = Scope.getRawOpens scope in
match completions with
| {env} :: _ -> (
let opens =
CompletionBackEnd.getOpens ~debug ~rawOpens ~package:full.package
~env
in
match
CompletionBackEnd.completionsGetCompletionType2 ~debug ~full
~rawOpens ~opens ~pos completions
with
| Some (typ, _env) ->
let extractedType =
match typ with
| ExtractedType t -> Some t
| TypeExpr t ->
TypeUtils.extractType t ~env ~package:full.package
|> TypeUtils.getExtractedType
in
extractedType
| None -> None)
| _ -> None)
| _ -> None
end

let mkIteratorSingle ~pos ~result =
let expr (iterator : Ast_iterator.iterator) (exp : Parsetree.expression) =
(match exp.pexp_desc with
Expand Down Expand Up @@ -434,7 +620,7 @@ module ExhaustiveSwitch = struct
| Some (Selection {expr}) -> (
match
expr
|> C.extractTypeFromExpr ~debug ~path ~currentFile ~full
|> extractTypeFromExpr ~debug ~path ~currentFile ~full
~pos:(Pos.ofLexing expr.pexp_loc.loc_start)
with
| None -> ()
Expand All @@ -460,7 +646,7 @@ module ExhaustiveSwitch = struct
| Some (Switch {switchExpr; completionExpr; pos}) -> (
match
completionExpr
|> C.extractTypeFromExpr ~debug ~path ~currentFile ~full ~pos
|> extractTypeFromExpr ~debug ~path ~currentFile ~full ~pos
with
| None -> ()
| Some extractedType -> (
Expand Down Expand Up @@ -743,6 +929,8 @@ let extractCodeActions ~path ~startPos ~endPos ~currentFile ~debug =
match Cmt.loadFullCmtFromPath ~path with
| Some full ->
AddTypeAnnotation.xform ~path ~pos ~full ~structure ~codeActions ~debug;
ExpandCatchAllForVariants.xform ~path ~pos ~full ~structure ~codeActions
~currentFile ~debug;
ExhaustiveSwitch.xform ~printExpr ~path
~pos:
(if startPos = endPos then Single startPos
Expand Down
43 changes: 38 additions & 5 deletions analysis/tests/src/Xform.res
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
type kind = First | Second | Third
type kind = First | Second | Third | Fourth(int)
type r = {name: string, age: int}

let ret = _ => assert false
let kind = assert false
let ret = _ => assert(false)
let kind = assert(false)

if kind == First {
// ^xfm
Expand Down Expand Up @@ -63,7 +63,7 @@ let bar = () => {
}
//^xfm
}
@res.partial Inner.foo(1)
Inner.foo(1, ...)
}

module ExtractableModule = {
Expand All @@ -72,4 +72,37 @@ module ExtractableModule = {
// A comment here
let doStuff = a => a + 1
// ^xfm
}
}

let variant = First

let _x = switch variant {
| First => "first"
| _ => "other"
// ^xfm
}

let polyvariant: [#first | #second | #"illegal identifier" | #third(int)] = #first

let _y = switch polyvariant {
| #first => "first"
| _ => "other"
// ^xfm
}

let variantOpt = Some(variant)

let _x = switch variantOpt {
| Some(First) => "first"
| _ => "other"
// ^xfm
}

let polyvariantOpt = Some(polyvariant)

let _x = switch polyvariantOpt {
| Some(#first) => "first"
| None => "nothing"
| _ => "other"
// ^xfm
}
Loading

0 comments on commit affe4ac

Please sign in to comment.