Skip to content

Commit

Permalink
[REPLCompletions] improve implementation of completions
Browse files Browse the repository at this point in the history
- Restrict method completion to ignore strictly less specific ones
- Fix various lookup bugs
- Improve slurping of final expression

Inspired by #43572
Co-authored-by: Lionel Zoubritzky <lionel.zoubritzky@gmail.com>
  • Loading branch information
vtjnash committed Jan 19, 2022
1 parent 0ae0a5b commit d36db59
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 100 deletions.
169 changes: 79 additions & 90 deletions stdlib/REPL/src/REPLCompletions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ using Base: propertynames, something

abstract type Completion end

struct TextCompletion <: Completion
text::String
end

struct KeywordCompletion <: Completion
keyword::String
end
Expand Down Expand Up @@ -37,10 +41,7 @@ struct FieldCompletion <: Completion
end

struct MethodCompletion <: Completion
func
input_types::Type
method::Method
orig_method::Union{Nothing,Method} # if `method` is a keyword method, keep the original method for sensible printing
end

struct BslashCompletion <: Completion
Expand All @@ -58,7 +59,9 @@ end

# interface definition
function Base.getproperty(c::Completion, name::Symbol)
if name === :keyword
if name === :text
return getfield(c, :text)::String
elseif name === :keyword
return getfield(c, :keyword)::String
elseif name === :path
return getfield(c, :path)::String
Expand All @@ -84,13 +87,14 @@ function Base.getproperty(c::Completion, name::Symbol)
return getfield(c, name)
end

_completion_text(c::TextCompletion) = c.text
_completion_text(c::KeywordCompletion) = c.keyword
_completion_text(c::PathCompletion) = c.path
_completion_text(c::ModuleCompletion) = c.mod
_completion_text(c::PackageCompletion) = c.package
_completion_text(c::PropertyCompletion) = string(c.property)
_completion_text(c::FieldCompletion) = string(c.field)
_completion_text(c::MethodCompletion) = sprint(io -> show(io, isnothing(c.orig_method) ? c.method : c.orig_method::Method))
_completion_text(c::MethodCompletion) = repr(c.method)
_completion_text(c::BslashCompletion) = c.bslash
_completion_text(c::ShellCompletion) = c.text
_completion_text(c::DictCompletion) = c.key
Expand Down Expand Up @@ -125,7 +129,7 @@ function filtered_mod_names(ffunc::Function, mod::Module, name::AbstractString,
end

# REPL Symbol Completions
function complete_symbol(sym::String, ffunc, context_module::Module=Main)
function complete_symbol(sym::String, @nospecialize(ffunc), context_module::Module=Main)
mod = context_module
name = sym

Expand Down Expand Up @@ -407,62 +411,48 @@ end
# will show it consist of Expr, QuoteNode's and Symbol's which all needs to
# be handled differently to iterate down to get the value of whitespace_chars.
function get_value(sym::Expr, fn)
if sym.head === :quote || sym.head === :inert
return sym.args[1], true
end
sym.head !== :. && return (nothing, false)
for ex in sym.args
ex, found = get_value(ex, fn)
!found && return (nothing, false)
fn, found = get_value(ex, fn)
!found && return (nothing, false)
end
return (fn, true)
end
get_value(sym::Symbol, fn) = isdefined(fn, sym) ? (getfield(fn, sym), true) : (nothing, false)
get_value(sym::QuoteNode, fn) = isdefined(fn, sym.value) ? (getfield(fn, sym.value), true) : (nothing, false)
get_value(sym::QuoteNode, fn) = (sym.value, true)
get_value(sym::GlobalRef, fn) = get_value(sym.name, sym.mod)
get_value(sym, fn) = (sym, true)

# Return the type of a getfield call expression
function get_type_getfield(ex::Expr, fn::Module)
length(ex.args) == 3 || return Any, false # should never happen, but just for safety
obj, x = ex.args[2:3]
fld, found = get_value(ex.args[3], fn)
fld isa Symbol || return Any, false
obj = ex.args[2]
objt, found = get_type(obj, fn)
objt isa DataType || return Any, false
found || return Any, false
if x isa QuoteNode
fld = x.value
elseif isexpr(x, :quote) || isexpr(x, :inert)
fld = x.args[1]
else
fld = nothing # we don't know how to get the value of variable `x` here
end
fld isa Symbol || return Any, false
objt isa DataType || return Any, false
hasfield(objt, fld) || return Any, false
return fieldtype(objt, fld), true
end

# Determines the return type with Base.return_types of a function call using the type information of the arguments.
function get_type_call(expr::Expr)
# Determines the return type with the Compiler of a function call using the type information of the arguments.
function get_type_call(expr::Expr, fn::Module)
f_name = expr.args[1]
# The if statement should find the f function. How f is found depends on how f is referenced
if isa(f_name, GlobalRef) && isconst(f_name.mod,f_name.name) && isdefined(f_name.mod,f_name.name)
ft = typeof(eval(f_name))
found = true
else
ft, found = get_type(f_name, Main)
end
f, found = get_type(f_name, fn)
found || return (Any, false) # If the function f is not found return Any.
args = Any[]
for ex in expr.args[2:end] # Find the type of the function arguments
typ, found = get_type(ex, Main)
for i in 2:length(expr.args) # Find the type of the function arguments
typ, found = get_type(expr.args[i], fn)
found ? push!(args, typ) : push!(args, Any)
end
# use _methods_by_ftype as the function is supplied as a type
world = Base.get_world_counter()
matches = Base._methods_by_ftype(Tuple{ft, args...}, -1, world)::Vector
length(matches) == 1 || return (Any, false)
match = first(matches)::Core.MethodMatch
# Typeinference
interp = Core.Compiler.NativeInterpreter()
return_type = Core.Compiler.typeinf_type(interp, match.method, match.spec_types, match.sparams)
return_type === nothing && return (Any, false)
return_type = Core.Compiler.return_type(Tuple{f, args...}, world)
return (return_type, true)
end

Expand All @@ -477,15 +467,15 @@ function try_get_type(sym::Expr, fn::Module)
if a1 === :getfield || a1 === GlobalRef(Core, :getfield)
return get_type_getfield(sym, fn)
end
return get_type_call(sym)
return get_type_call(sym, fn)
elseif sym.head === :thunk
thk = sym.args[1]
rt = ccall(:jl_infer_thunk, Any, (Any, Any), thk::Core.CodeInfo, fn)
rt !== Any && return (rt, true)
elseif sym.head === :ref
# some simple cases of `expand`
return try_get_type(Expr(:call, GlobalRef(Base, :getindex), sym.args...), fn)
elseif sym.head === :. && sym.args[2] isa QuoteNode # second check catches broadcasting
elseif sym.head === :. && sym.args[2] isa QuoteNode # second check catches broadcasting
return try_get_type(Expr(:call, GlobalRef(Core, :getfield), sym.args...), fn)
end
return (Any, false)
Expand Down Expand Up @@ -525,37 +515,52 @@ function get_type(T, found::Bool, default_any::Bool)
end

# Method completion on function call expression that look like :(max(1))
MAX_METHOD_COMPLETIONS = 40
function complete_methods(ex_org::Expr, context_module::Module=Main)
func, found = get_value(ex_org.args[1], context_module)::Tuple{Any,Bool}
!found && return Completion[]
out = Completion[]
funct, found = get_type(ex_org.args[1], context_module)::Tuple{Any,Bool}
!found && return out

args_ex, kwargs_ex = complete_methods_args(ex_org.args[2:end], ex_org, context_module, true, true)
push!(args_ex, Vararg{Any})
complete_methods!(out, funct, args_ex, kwargs_ex, MAX_METHOD_COMPLETIONS::Int)

out = Completion[]
complete_methods!(out, func, args_ex, kwargs_ex)
return out
end

MAX_ANY_METHOD_COMPLETIONS = 10
function complete_any_methods(ex_org::Expr, callee_module::Module, context_module::Module, moreargs::Bool, shift::Bool)
out = Completion[]
args_ex, kwargs_ex = try
# this may throw, since we set default_any to false
complete_methods_args(ex_org.args[2:end], ex_org, context_module, false, false)
catch
catch ex
ex isa ArgumentError || rethrow()
return out
end
moreargs && push!(args_ex, Vararg{Any})

seen = Base.IdSet()
for name in names(callee_module; all=true)
if !Base.isdeprecated(callee_module, name) && isdefined(callee_module, name)
func = getfield(callee_module, name)
if !isa(func, Module)
complete_methods!(out, func, args_ex, kwargs_ex, moreargs)
elseif callee_module === Main::Module && isa(func, Module)
funct = Core.Typeof(func)
if !in(funct, seen)
push!(seen, funct)
complete_methods!(out, funct, args_ex, kwargs_ex, MAX_ANY_METHOD_COMPLETIONS::Int)
end
elseif callee_module === Main && isa(func, Module)
callee_module2 = func
for name in names(callee_module2)
if isdefined(callee_module2, name)
if !Base.isdeprecated(callee_module2, name) && isdefined(callee_module2, name)
func = getfield(callee_module, name)
if !isa(func, Module)
complete_methods!(out, func, args_ex, kwargs_ex, moreargs)
funct = Core.Typeof(func)
if !in(funct, seen)
push!(seen, funct)
complete_methods!(out, funct, args_ex, kwargs_ex, MAX_ANY_METHOD_COMPLETIONS::Int)
end
end
end
end
Expand All @@ -566,7 +571,8 @@ function complete_any_methods(ex_org::Expr, callee_module::Module, context_modul
if !shift
# Filter out methods where all arguments are `Any`
filter!(out) do c
isa(c, REPLCompletions.MethodCompletion) || return true
isa(c, TextCompletion) && return false
isa(c, MethodCompletion) || return true
sig = Base.unwrap_unionall(c.method.sig)::DataType
return !all(T -> T === Any || T === Vararg{Any}, sig.parameters[2:end])
end
Expand All @@ -577,7 +583,7 @@ end

function complete_methods_args(funargs::Vector{Any}, ex_org::Expr, context_module::Module, default_any::Bool, allow_broadcasting::Bool)
args_ex = Any[]
kwargs_ex = Pair{Symbol,Any}[]
kwargs_ex = false
if allow_broadcasting && ex_org.head === :. && ex_org.args[2] isa Expr
# handle broadcasting, but only handle number of arguments instead of
# argument types
Expand All @@ -587,13 +593,11 @@ function complete_methods_args(funargs::Vector{Any}, ex_org::Expr, context_modul
else
for ex in funargs
if isexpr(ex, :parameters)
for x in ex.args
n, v = isexpr(x, :kw) ? (x.args...,) : (x, x)
push!(kwargs_ex, n => get_type(get_type(v, context_module)..., default_any))
if !isempty(ex.args)
kwargs_ex = true
end
elseif isexpr(ex, :kw)
n, v = (ex.args...,)
push!(kwargs_ex, n => get_type(get_type(v, context_module)..., default_any))
kwargs_ex = true
else
push!(args_ex, get_type(get_type(ex, context_module)..., default_any))
end
Expand All @@ -602,34 +606,18 @@ function complete_methods_args(funargs::Vector{Any}, ex_org::Expr, context_modul
return args_ex, kwargs_ex
end

function complete_methods!(out::Vector{Completion}, @nospecialize(func), args_ex::Vector{Any}, kwargs_ex::Vector{Pair{Symbol,Any}}, moreargs::Bool=true)
ml = methods(func)
function complete_methods!(out::Vector{Completion}, @nospecialize(funct), args_ex::Vector{Any}, kwargs_ex::Bool, max_method_completions::Int)
# Input types and number of arguments
if isempty(kwargs_ex)
t_in = Tuple{Core.Typeof(func), args_ex...}
na = length(t_in.parameters)::Int
orig_ml = fill(nothing, length(ml))
else
isdefined(ml.mt, :kwsorter) || return out
kwfunc = ml.mt.kwsorter
kwargt = NamedTuple{(first.(kwargs_ex)...,), Tuple{last.(kwargs_ex)...}}
t_in = Tuple{Core.Typeof(kwfunc), kwargt, Core.Typeof(func), args_ex...}
na = length(t_in.parameters)::Int
orig_ml = ml # this method is supposed to be used for printing
ml = methods(kwfunc)
func = kwfunc
end
if !moreargs
na = typemax(Int)
t_in = Tuple{funct, args_ex...}
m = Base._methods_by_ftype(t_in, nothing, max_method_completions, Base.get_world_counter(),
#=ambig=# true, Ref(typemin(UInt)), Ref(typemax(UInt)), Ptr{Int32}(C_NULL))
if m === false
push!(out, TextCompletion(sprint(Base.show_signature_function, funct) * "( too many methods to show )"))
end

for (method::Method, orig_method) in zip(ml, orig_ml)
ms = method.sig

# Check if the method's type signature intersects the input types
if typeintersect(Base.rewrap_unionall(Tuple{(Base.unwrap_unionall(ms)::DataType).parameters[1 : min(na, end)]...}, ms), t_in) != Union{}
push!(out, MethodCompletion(func, t_in, method, orig_method))
end
m isa Vector || return
for match in m
# TODO: if kwargs_ex, filter out methods without kwargs?
push!(out, MethodCompletion(match.method))
end
end

Expand Down Expand Up @@ -708,7 +696,7 @@ function bslash_completions(string::String, pos::Int)
return (false, (Completion[], 0:-1, false))
end

function dict_identifier_key(str::String, tag::Symbol, context_module::Module = Main)
function dict_identifier_key(str::String, tag::Symbol, context_module::Module=Main)
if tag === :string
str_close = str*"\""
elseif tag === :cmd
Expand Down Expand Up @@ -897,21 +885,22 @@ function completions(string::String, pos::Int, context_module::Module=Main, shif
dotpos < startpos && (dotpos = startpos - 1)
s = string[startpos:pos]
comp_keywords && append!(suggestions, complete_keyword(s))
# The case where dot and start pos is equal could look like: "(""*"").d","". or CompletionFoo.test_y_array[1].y
# This case can be handled by finding the beginning of the expression. This is done below.
if dotpos == startpos
# if the start of the string is a `.`, try to consume more input to get back to the beginning of the last expression
if 0 < startpos <= lastindex(string) && string[startpos] == '.'
i = prevind(string, startpos)
while 0 < i
c = string[i]
if c in [')', ']']
if c==')'
c_start='('; c_end=')'
elseif c==']'
c_start='['; c_end=']'
if c in (')', ']')
if c == ')'
c_start = '('
c_end = ')'
elseif c == ']'
c_start = '['
c_end = ']'
end
frange, end_of_identifier = find_start_brace(string[1:prevind(string, i)], c_start=c_start, c_end=c_end)
isempty(frange) && break # unbalanced parens
startpos = first(frange)
startpos == 0 && break
i = prevind(string, startpos)
elseif c in ('\'', '\"', '\`')
s = "$c$c"*string[startpos:pos]
Expand Down
Loading

0 comments on commit d36db59

Please sign in to comment.