Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type narrowing on assignment #622

Merged
merged 1 commit into from
Jul 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions lib/steep/type_construction.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1844,8 +1844,21 @@ def synthesize(node, hint: nil, condition: false)
if cond
branch_results = []

cond_type, constr = constr.synthesize(cond).to_ary
cond_value_node, cond_vars = interpreter.decompose_value(cond)
cond_type, constr = constr.synthesize(cond)
_, cond_vars = interpreter.decompose_value(cond)

var_name = :"_a#{SecureRandom.base64(4)}"
var_cond, value_node = extract_outermost_call(cond, var_name)
if value_node
unless constr.context.type_env[value_node]
constr = constr.update_type_env do |env|
env.assign_local_variable(var_name, cond_type, nil)
end
cond = var_cond
else
value_node = nil
end
end

when_constr = constr
whens.each do |clause|
Expand Down Expand Up @@ -1893,7 +1906,11 @@ def synthesize(node, hint: nil, condition: false)
types = branch_results.map(&:type)
constrs = branch_results.map(&:constr)

cond_type = when_constr.context.type_env[cond_value_node]
cond_type = when_constr.context.type_env[var_name]
cond_type ||= when_constr.context.type_env[cond_vars.first || raise] unless cond_vars.empty?
cond_type ||= when_constr.context.type_env[value_node] if value_node
cond_type ||= typing.type_of(node: node.children[0])

if cond_type.is_a?(AST::Types::Bot)
# Exhaustive
if els
Expand Down Expand Up @@ -4235,5 +4252,35 @@ def save_typing
typing.save!
with_new_typing(typing.parent)
end

def extract_outermost_call(node, var_name)
case node.type
when :lvasgn
name, rhs = node.children
rhs, value_node = extract_outermost_call(rhs, var_name)
if value_node
[node.updated(nil, [name, rhs]), value_node]
else
[node, value_node]
end
when :begin
*children, last = node.children
last, value_node = extract_outermost_call(last, var_name)
if value_node
[node.updated(nil, children.push(last)), value_node]
else
[node, value_node]
end
when :lvar
[node, nil]
else
if value_node?(node)
[node, nil]
else
var_node = node.updated(:lvar, [var_name])
[var_node, node]
end
end
end
end
end
2 changes: 2 additions & 0 deletions sig/steep/node_helper.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ module Steep
def each_descendant_node: (Parser::AST::Node) -> Enumerator[Parser::AST::Node, void]
| (Parser::AST::Node) { (Parser::AST::Node) -> void } -> void

# Returns true if given node is a syntactic-value node
#
def value_node?: (Parser::AST::Node) -> bool
end
end
13 changes: 13 additions & 0 deletions sig/steep/type_construction.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -283,5 +283,18 @@ module Steep
# * All of the arguments are _pure_
#
def pure_send?: (TypeInference::MethodCall::Typed call, Parser::AST::Node receiver, Array[Parser::AST::Node] arguments) -> bool

# Transform given `node` to a node that has a local variable instead of the outer most call/non-value node
#
# Returns a pair of transformed node and the outer most call/non-value node if present.
#
# ```rb
# x = y = foo() # Call `#transform_value_node` with the node and var_name `:__foo__`
# # => Returns [x = y = __foo__, foo()]
# ```
#
# This is typically used for transforming assginment node for case condition.
#
def extract_outermost_call: (Parser::AST::Node node, Symbol) -> [Parser::AST::Node, Parser::AST::Node?]
end
end
24 changes: 24 additions & 0 deletions test/type_construction_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9758,12 +9758,36 @@ def test_lvar_special
end

-> (_) { 123 }
RUBY
with_standard_construction(checker, source) do |construction, typing|
type, _, context = construction.synthesize(source.node)

assert_no_error typing
end
end
end

def test_type_narrowing_assignment
with_checker(<<-RBS) do |checker|
class NarrowingAssignmentTest
def foo: () -> (Integer | String)
end
RBS
source = parse_ruby(<<-RUBY)
case value = NarrowingAssignmentTest.new.foo()
when Integer
"hello"
when String
value
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
type, _, context = construction.synthesize(source.node)

assert_no_error typing

assert_equal parse_type("::String"), type
end
end
end
Expand Down